from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from typing import Literal, List, Union, Dict
from onyxengine.modeling import MLPOptConfig, RNNOptConfig, TransformerOptConfig, validate_param, validate_opt_param
class AdamWConfig(BaseModel):
    """
    Hyperparameters for the AdamW optimizer.

    Args:
        lr (float): Learning rate (default is 3e-4).
        weight_decay (float): Weight decay (default is 1e-2).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='adamw', frozen=True, init=False)
    lr: float = 3e-4
    weight_decay: float = 1e-2

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_param with a lower bound of 0.0.

        Returns:
            Self: This instance, unchanged.
        """
        # Both fields share the same bound, so check them in declaration order.
        for field_name in ('lr', 'weight_decay'):
            validate_param(getattr(self, field_name), field_name, min_val=0.0)
        return self
class AdamWOptConfig(BaseModel):
    """
    Search-space configuration for the AdamW optimizer.

    Each hyperparameter is either a single fixed value or a dict mapping an
    option name to a list of values, e.g. ``{"select": [...]}``.

    Args:
        lr (Union[float, Dict[str, List[float]]]): Learning rate
            (default is {"select": [1e-5, 5e-5, 1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 5e-3, 1e-2]}).
        weight_decay (Union[float, Dict[str, List[float]]]): Weight decay
            (default is {"select": [1e-4, 1e-3, 1e-2, 1e-1]}).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='adamw_opt', frozen=True, init=False)
    lr: Union[float, Dict[str, List[float]]] = {
        "select": [1e-5, 5e-5, 1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 5e-3, 1e-2],
    }
    weight_decay: Union[float, Dict[str, List[float]]] = {
        "select": [1e-4, 1e-3, 1e-2, 1e-1],
    }

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_opt_param.

        Every field is checked with the option names 'select' and 'range'
        and a lower bound of 0.0.

        Returns:
            Self: This instance, unchanged.
        """
        for field_name in ('lr', 'weight_decay'):
            validate_opt_param(
                getattr(self, field_name),
                field_name,
                options=['select', 'range'],
                min_val=0.0,
            )
        return self
class SGDConfig(BaseModel):
    """
    Hyperparameters for the SGD optimizer.

    Args:
        lr (float): Learning rate (default is 3e-4).
        weight_decay (float): Weight decay (default is 1e-2).
        momentum (float): Momentum (default is 0.9).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='sgd', frozen=True, init=False)
    lr: float = 3e-4
    weight_decay: float = 1e-2
    momentum: float = 0.9

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_param.

        lr and weight_decay get a lower bound of 0.0; momentum is bounded
        by 0.0 and 1.0.

        Returns:
            Self: This instance, unchanged.
        """
        for field_name in ('lr', 'weight_decay'):
            validate_param(getattr(self, field_name), field_name, min_val=0.0)
        # Momentum is the only field that also carries an upper bound.
        validate_param(self.momentum, 'momentum', min_val=0.0, max_val=1.0)
        return self
class SGDOptConfig(BaseModel):
    """
    Search-space configuration for the SGD optimizer.

    Each hyperparameter is either a single fixed value or a dict mapping an
    option name to a list of values, e.g. ``{"select": [...]}``.

    Args:
        lr (Union[float, Dict[str, List[float]]]): Learning rate
            (default is {"select": [1e-5, 5e-5, 1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 5e-3, 1e-2]}).
        weight_decay (Union[float, Dict[str, List[float]]]): Weight decay
            (default is {"select": [1e-4, 1e-3, 1e-2, 1e-1]}).
        momentum (Union[float, Dict[str, List[float]]]): Momentum
            (default is {"select": [0.0, 0.8, 0.9, 0.95, 0.99]}).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='sgd_opt', frozen=True, init=False)
    lr: Union[float, Dict[str, List[float]]] = {
        "select": [1e-5, 5e-5, 1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 5e-3, 1e-2],
    }
    weight_decay: Union[float, Dict[str, List[float]]] = {
        "select": [1e-4, 1e-3, 1e-2, 1e-1],
    }
    momentum: Union[float, Dict[str, List[float]]] = {
        "select": [0.0, 0.8, 0.9, 0.95, 0.99],
    }

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_opt_param.

        All fields use the option names 'select' and 'range'. lr and
        weight_decay get a lower bound of 0.0; momentum is bounded by 0.0
        and 1.0.

        Returns:
            Self: This instance, unchanged.
        """
        for field_name in ('lr', 'weight_decay'):
            validate_opt_param(
                getattr(self, field_name),
                field_name,
                options=['select', 'range'],
                min_val=0.0,
            )
        # Momentum is the only field that also carries an upper bound.
        validate_opt_param(
            self.momentum,
            'momentum',
            options=['select', 'range'],
            min_val=0.0,
            max_val=1.0,
        )
        return self
class CosineDecayWithWarmupConfig(BaseModel):
    """
    Settings for a learning rate scheduler with linear warmup and cosine decay.

    Args:
        max_lr (float): Maximum learning rate (default is 3e-4).
        min_lr (float): Minimum learning rate (default is 3e-5).
        warmup_iters (int): Number of warmup iterations (default is 200).
        decay_iters (int): Number of decay iterations (default is 1000).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='cosine_decay_with_warmup', frozen=True, init=False)
    max_lr: float = 3e-4
    min_lr: float = 3e-5
    warmup_iters: int = 200
    decay_iters: int = 1000

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_param with a lower bound of zero.

        Returns:
            Self: This instance, unchanged.
        """
        # (field, lower bound) pairs; float fields use 0.0, int fields use 0.
        lower_bounds = (
            ('max_lr', 0.0),
            ('min_lr', 0.0),
            ('warmup_iters', 0),
            ('decay_iters', 0),
        )
        for field_name, floor in lower_bounds:
            validate_param(getattr(self, field_name), field_name, min_val=floor)
        return self
class CosineDecayWithWarmupOptConfig(BaseModel):
    """
    Search-space configuration for the cosine-decay-with-warmup scheduler.

    Each hyperparameter is either a single fixed value or a dict mapping an
    option name to a list of values, e.g. ``{"select": [...]}``.

    Args:
        max_lr (Union[float, Dict[str, List[float]]]): Maximum learning rate
            (default is {"select": [1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 3e-3, 5e-3]}).
        min_lr (Union[float, Dict[str, List[float]]]): Minimum learning rate
            (default is {"select": [1e-6, 5e-6, 1e-5, 3e-5, 5e-5, 8e-5, 1e-4]}).
        warmup_iters (Union[int, Dict[str, List[int]]]): Number of warmup iterations
            (default is {"select": [50, 100, 200, 400, 800]}).
        decay_iters (Union[int, Dict[str, List[int]]]): Number of decay iterations
            (default is {"select": [500, 1000, 2000, 4000, 8000]}).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='cosine_decay_with_warmup_opt', frozen=True, init=False)
    max_lr: Union[float, Dict[str, List[float]]] = {
        "select": [1e-4, 3e-4, 5e-4, 8e-4, 1e-3, 3e-3, 5e-3],
    }
    min_lr: Union[float, Dict[str, List[float]]] = {
        "select": [1e-6, 5e-6, 1e-5, 3e-5, 5e-5, 8e-5, 1e-4],
    }
    warmup_iters: Union[int, Dict[str, List[int]]] = {
        "select": [50, 100, 200, 400, 800],
    }
    decay_iters: Union[int, Dict[str, List[int]]] = {
        "select": [500, 1000, 2000, 4000, 8000],
    }

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_opt_param.

        Every field is checked with the option names 'select' and 'range'
        and a lower bound of zero.

        Returns:
            Self: This instance, unchanged.
        """
        # (field, lower bound) pairs; float fields use 0.0, int fields use 0.
        lower_bounds = (
            ('max_lr', 0.0),
            ('min_lr', 0.0),
            ('warmup_iters', 0),
            ('decay_iters', 0),
        )
        for field_name, floor in lower_bounds:
            validate_opt_param(
                getattr(self, field_name),
                field_name,
                options=['select', 'range'],
                min_val=floor,
            )
        return self
class CosineAnnealingWarmRestartsConfig(BaseModel):
    """
    Configuration for learning rate scheduler with cosine annealing and warm restarts.

    Args:
        T_0 (int): Initial period of learning rate decay; validated with a
            lower bound of 1 (default is 2000).
        T_mult (int): Multiplicative factor for the period of learning rate
            decay; validated with a lower bound of 1 (default is 1).
        eta_min (float): Minimum learning rate; validated with a lower bound
            of 0.0 (default is 3e-5).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='cosine_annealing_warm_restarts', frozen=True, init=False)
    T_0: int = 2000
    T_mult: int = 1
    eta_min: float = 3e-5

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_param.

        Returns:
            Self: This instance, unchanged.
        """
        # A restart period of 0 or a period multiplier of 0 is degenerate, and
        # PyTorch's CosineAnnealingWarmRestarts rejects T_0 <= 0 and T_mult < 1,
        # so the lower bound is 1 rather than 0.
        # NOTE(review): assumes validate_param treats min_val as inclusive,
        # so the default T_mult of 1 still passes - confirm.
        validate_param(self.T_0, 'T_0', min_val=1)
        validate_param(self.T_mult, 'T_mult', min_val=1)
        validate_param(self.eta_min, 'eta_min', min_val=0.0)
        return self
class CosineAnnealingWarmRestartsOptConfig(BaseModel):
    """
    Optimization config for learning rate scheduler with cosine annealing and warm restarts.

    Each hyperparameter is either a single fixed value or a dict mapping an
    option name to a list of values, e.g. ``{"select": [...]}``.

    Args:
        T_0 (Union[int, Dict[str, List[int]]]): Initial period of learning rate decay;
            validated with a lower bound of 1
            (default is {"select": [200, 500, 1000, 2000, 5000, 10000]}).
        T_mult (Union[int, Dict[str, List[int]]]): Multiplicative factor for the period
            of learning rate decay; validated with a lower bound of 1
            (default is {"select": [1, 2, 3]}).
        eta_min (Union[float, Dict[str, List[float]]]): Minimum learning rate;
            validated with a lower bound of 0.0
            (default is {"select": [1e-6, 5e-6, 1e-5, 3e-5, 5e-5, 8e-5, 1e-4, 3e-4]}).
    """
    # Identifier of this config type; frozen, so it cannot be reassigned.
    name: str = Field(default='cosine_annealing_warm_restarts_opt', frozen=True, init=False)
    T_0: Union[int, Dict[str, List[int]]] = {
        "select": [200, 500, 1000, 2000, 5000, 10000],
    }
    T_mult: Union[int, Dict[str, List[int]]] = {
        "select": [1, 2, 3],
    }
    eta_min: Union[float, Dict[str, List[float]]] = {
        "select": [1e-6, 5e-6, 1e-5, 3e-5, 5e-5, 8e-5, 1e-4, 3e-4],
    }

    @model_validator(mode='after')
    def validate_hyperparameters(self) -> Self:
        """
        Pass each hyperparameter to validate_opt_param.

        Every field is checked with the option names 'select' and 'range'.

        Returns:
            Self: This instance, unchanged.
        """
        # A restart period of 0 or a period multiplier of 0 is degenerate, and
        # PyTorch's CosineAnnealingWarmRestarts rejects T_0 <= 0 and T_mult < 1,
        # so candidate values are bounded below by 1 rather than 0.
        # NOTE(review): assumes validate_opt_param treats min_val as inclusive,
        # so the default T_mult candidate of 1 still passes - confirm.
        validate_opt_param(self.T_0, 'T_0', options=['select', 'range'], min_val=1)
        validate_opt_param(self.T_mult, 'T_mult', options=['select', 'range'], min_val=1)
        validate_opt_param(self.eta_min, 'eta_min', options=['select', 'range'], min_val=0.0)
        return self
# [docs]  (link marker left over from the rendered documentation source view; not part of the module)
class TrainingConfig(BaseModel):
    """
    Settings for training a single model.

    Args:
        training_iters (int): Number of training iterations (default is 3000).
        train_batch_size (int): Batch size for training (default is 32).
        train_val_split_ratio (float): Ratio of training data to validation data
            (default is 0.9).
        test_dataset_size (int): Number of samples in the test dataset (default is 500).
        checkpoint_type (Literal['single_step', 'multi_step']): Type of checkpointing
            (default is 'single_step').
        optimizer (Union[AdamWConfig, SGDConfig]): Optimizer configuration
            (default is AdamWConfig()).
        lr_scheduler (Union[None, CosineDecayWithWarmupConfig, CosineAnnealingWarmRestartsConfig]):
            Learning rate scheduler configuration (default is None).
    """
    # Training loop and data split settings.
    training_iters: int = 3000
    train_batch_size: int = 32
    train_val_split_ratio: float = 0.9
    test_dataset_size: int = 500
    checkpoint_type: Literal['single_step', 'multi_step'] = 'single_step'

    # Optimizer and optional learning rate scheduler.
    optimizer: Union[AdamWConfig, SGDConfig] = AdamWConfig()
    lr_scheduler: Union[
        None,
        CosineDecayWithWarmupConfig,
        CosineAnnealingWarmRestartsConfig,
    ] = None
class OptimizationConfig(BaseModel):
    """
    Settings for optimizing over models, optimizers and learning rate schedulers.

    Args:
        training_iters (int): Number of training iterations (default is 3000).
        train_batch_size (int): Batch size for training (default is 32).
        train_val_split_ratio (float): Ratio of training data to validation data
            (default is 0.9).
        test_dataset_size (int): Number of samples in the test dataset (default is 500).
        checkpoint_type (Literal['single_step', 'multi_step']): Type of checkpointing
            (default is 'single_step').
        opt_models (List[Union[MLPOptConfig, RNNOptConfig, TransformerOptConfig]]):
            List of model optimization configurations (default is []).
        opt_optimizers (List[Union[AdamWOptConfig, SGDOptConfig]]):
            List of optimizer optimization configurations (default is []).
        opt_lr_schedulers (List[Union[None, CosineDecayWithWarmupOptConfig, CosineAnnealingWarmRestartsOptConfig]]):
            List of learning rate scheduler optimization configurations
            (default is [None]).
        num_trials (int): Number of optimization trials (default is 10).
    """
    # Training loop and data split settings.
    training_iters: int = 3000
    train_batch_size: int = 32
    train_val_split_ratio: float = 0.9
    test_dataset_size: int = 500
    checkpoint_type: Literal['single_step', 'multi_step'] = 'single_step'

    # Candidate search spaces.
    opt_models: List[
        Union[MLPOptConfig, RNNOptConfig, TransformerOptConfig]
    ] = []
    opt_optimizers: List[
        Union[AdamWOptConfig, SGDOptConfig]
    ] = []
    opt_lr_schedulers: List[
        Union[None, CosineDecayWithWarmupOptConfig, CosineAnnealingWarmRestartsOptConfig]
    ] = [None]

    num_trials: int = 10