```python
onyx.train_model(
    model_name: str = "",
    model_config: Union[MLPConfig, RNNConfig, TransformerConfig] = None,
    dataset_name: str = "",
    dataset_version_id: Optional[str] = None,
    training_config: TrainingConfig = TrainingConfig(),
)
```
Trains a model on the Onyx Engine using the specified dataset and configuration.
Parameters
model_name
str
required
The name for the trained model. Must be a non-empty string.
model_config
Union[MLPConfig, RNNConfig, TransformerConfig]
required
The model architecture configuration.
dataset_name
str
required
The name of the dataset to train on. Must be a non-empty string.
dataset_version_id
Optional[str]
default:"None"
The specific dataset version to use. If None, the latest version of the dataset is used.
training_config
TrainingConfig
default:"TrainingConfig()"
Configuration for the training process, such as the number of training iterations, the batch size, the checkpoint type, and the optimizer.
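The `dataset_version_id` fallback can be pictured with the minimal sketch below. `DatasetVersion` and `resolve_dataset_version` are hypothetical names, not part of `onyxengine`, and treating "latest" as the version with the newest creation timestamp is an assumption about how the Engine orders versions.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional


@dataclass
class DatasetVersion:
    # Hypothetical version record; not an onyxengine type
    version_id: str
    created_at: datetime


def resolve_dataset_version(
    versions: List[DatasetVersion], dataset_version_id: Optional[str] = None
) -> DatasetVersion:
    """Return the requested version, or the newest one when the id is None."""
    if not versions:
        raise ValueError("dataset has no versions")
    if dataset_version_id is None:
        # Assumed meaning of "latest": the most recently created version
        return max(versions, key=lambda v: v.created_at)
    for version in versions:
        if version.version_id == dataset_version_id:
            return version
    raise ValueError(f"unknown dataset version: {dataset_version_id}")
```

Passing an explicit id pins training to that exact snapshot, which keeps runs reproducible even after new versions of the dataset are uploaded.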
Returns
None. Training runs on the Engine and the model is saved automatically upon completion.
Raises
Exception: If model_name or dataset_name is empty
AssertionError: If parameters have incorrect types
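A sketch of the checks listed under Raises, assuming they run client-side before anything is sent to the Engine. `validate_train_model_args` is a hypothetical helper, not an `onyxengine` function; it skips the `model_config` and `training_config` type checks so it runs without the library installed, and the order of the checks is an assumption.

```python
from typing import Optional


def validate_train_model_args(
    model_name: str,
    dataset_name: str,
    dataset_version_id: Optional[str] = None,
) -> None:
    """Hypothetical mirror of the documented argument checks."""
    # Incorrect types -> AssertionError
    assert isinstance(model_name, str), "model_name must be a str"
    assert isinstance(dataset_name, str), "dataset_name must be a str"
    assert dataset_version_id is None or isinstance(dataset_version_id, str), (
        "dataset_version_id must be a str or None"
    )
    # Empty names -> Exception
    if not model_name:
        raise Exception("model_name must be a non-empty string")
    if not dataset_name:
        raise Exception("dataset_name must be a non-empty string")
```

Because `AssertionError` is a subclass of `Exception`, a caller that catches `Exception` sees both failure modes; check `type(err)` if the distinction matters.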
Example
```python
from onyxengine import Onyx
from onyxengine.modeling import (
    Output, Input, MLPConfig, TrainingConfig, AdamWConfig
)

# Initialize the client
onyx = Onyx()

# Define model structure
outputs = [Output(name='acceleration_predicted')]
inputs = [
    Input(name='velocity', parent='acceleration_predicted', relation='derivative'),
    Input(name='position', parent='velocity', relation='derivative'),
    Input(name='control_input'),
]

# Configure model
model_config = MLPConfig(
    outputs=outputs,
    inputs=inputs,
    dt=0.0025,
    sequence_length=8,
    hidden_layers=3,
    hidden_size=64,
    activation='relu',
    dropout=0.2
)

# Configure training
training_config = TrainingConfig(
    training_iters=2000,
    train_batch_size=1024,
    checkpoint_type='single_step',
    optimizer=AdamWConfig(lr=3e-4, weight_decay=1e-2)
)

# Start training
onyx.train_model(
    model_name='example_model',
    model_config=model_config,
    dataset_name='example_train_data',
    training_config=training_config,
)
```
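The two `relation='derivative'` inputs in the example appear to link the predicted acceleration back to the state. The sketch below assumes that `parent=X, relation='derivative'` means X is the time derivative of the input, so each input is advanced by forward Euler with the model's `dt`. It illustrates that assumption only and is not the Engine's implementation: `euler_step` and `rollout` are hypothetical helpers, and a fixed list of accelerations stands in for the trained model's predictions. `control_input` has no parent, so it would be supplied externally at every step and is omitted here.

```python
from typing import Dict, List


def euler_step(state: Dict[str, float], acceleration: float, dt: float) -> Dict[str, float]:
    """Advance velocity and position by one forward-Euler step (assumed semantics)."""
    # velocity's parent is acceleration_predicted: d(velocity)/dt = acceleration
    # position's parent is velocity: d(position)/dt = velocity
    # Both updates read the current-step values (explicit Euler).
    return {
        "velocity": state["velocity"] + dt * acceleration,
        "position": state["position"] + dt * state["velocity"],
    }


def rollout(
    state: Dict[str, float], accelerations: List[float], dt: float = 0.0025
) -> List[Dict[str, float]]:
    """Integrate a sequence of accelerations; dt defaults to the example's 0.0025."""
    states = [state]
    for acceleration in accelerations:
        state = euler_step(state, acceleration, dt)
        states.append(state)
    return states
```

Forward Euler is the simplest choice; a semi-implicit variant would update position with the new velocity instead, which can matter for stiff or oscillatory systems.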
Notes
- Training runs on GPU-accelerated infrastructure
- The trained model is automatically saved to the Engine
- Load the trained model with `onyx.load_model('example_model')`
- Monitor detailed progress in the Engine Platform