class State(BaseModel):
    """Describe how a single simulated state variable is obtained.

    Used as an entry of ``ModelSimulatorConfig.states``.

    Attributes:
        name: Identifier of the state variable.
        relation: How the variable is solved for from its parent:
            ``'output'`` - the variable is an output of the model;
            ``'delta'`` - the parent is the delta of the variable;
            ``'derivative'`` - the parent is the derivative of the variable.
        parent: Name of the variable this one is derived from.
    """

    name: str
    relation: Literal['output', 'delta', 'derivative']
    parent: str
class ModelSimulatorConfig(BaseModel):
    """Settings consumed by the model simulator.

    Attributes:
        outputs: Names of the model output variables.
        states: State variable descriptions (see ``State``).
        controls: Names of the control variables.
        dt: Simulation time step; defaults to 0.
    """

    outputs: List[str] = []
    states: List[State] = []
    controls: List[str] = []
    dt: float = 0

    @property
    def num_outputs(self):
        """Number of configured output variables."""
        return len(self.outputs)

    @property
    def num_states(self):
        """Number of configured state variables."""
        return len(self.states)

    @property
    def num_controls(self):
        """Number of configured control variables."""
        return len(self.controls)

    @property
    def num_inputs(self):
        """Total input width: state variables plus control variables."""
        return self.num_states + self.num_controls
class ModelSimulator:
    """Roll a one-step model forward in time, filling in state variables.

    This class does not define ``forward``, ``parameters`` or ``config``;
    it calls ``self.forward(...)``, ``self.parameters()`` and reads
    ``self.config.sequence_length``, so the host class must supply them
    (presumably a ``torch.nn.Module`` subclass -- confirm against users).
    """

    def __init__(self, sim_config: "ModelSimulatorConfig"):
        """Pre-compute the update order of the state variables.

        Args:
            sim_config: Simulator configuration providing ``outputs``,
                ``states``, ``controls`` and ``dt``.

        Raises:
            ValueError: If a state's parent is neither an output nor a
                state, or if the state dependencies are cyclic.
        """
        # Separate variables by dependency, then reorder to ensure parents
        # are computed first.
        output_dep_vars, state_dep_vars = self._separate_variables(sim_config)
        self.output_dep_vars = self._resolve_variable_order(output_dep_vars)
        self.state_dep_vars = self._resolve_variable_order(state_dep_vars)
        self.dt = sim_config.dt
        # Mixed precision is off until simulator_mixed_precision(True).
        self.amp_context = nullcontext()
        self.n_state = len(sim_config.states)
        self.n_control = len(sim_config.controls)

    def simulator_mixed_precision(self, setting: bool):
        """Enable or disable autocast for ``simulate``.

        Enabling also switches on TF32 for CUDA matmul and cuDNN, which are
        process-wide torch backend flags.

        Args:
            setting: ``True`` to run simulation under ``torch.amp.autocast``,
                ``False`` to restore a no-op context.
        """
        if not setting:
            self.amp_context = nullcontext()
            return

        # Fall back to CPU when the host has no parameters instead of
        # letting next() raise StopIteration.
        first_param = next(self.parameters(), None)
        on_cuda = first_param is not None and first_param.is_cuda
        device = 'cuda' if on_cuda else 'cpu'

        torch.backends.cuda.matmul.allow_tf32 = True  # Allow tf32 on matmul
        torch.backends.cudnn.allow_tf32 = True  # Allow tf32 on cudnn

        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
        self.amp_context = torch.amp.autocast(device_type=device, dtype=amp_dtype)

    def _separate_variables(self, sim_config: "ModelSimulatorConfig"):
        """Split states into output-dependent and state-dependent groups.

        Args:
            sim_config: Simulator configuration.

        Returns:
            Two lists of ``(state, state_index, parent_index)`` tuples.
            In the first, ``parent_index`` indexes ``sim_config.outputs``;
            in the second, it indexes ``sim_config.states``.

        Raises:
            ValueError: If a state's parent is neither an output name nor
                a state name.
        """
        output_names = sim_config.outputs
        state_names = [state.name for state in sim_config.states]

        output_dep_vars = []
        state_dep_vars = []
        for idx, var in enumerate(sim_config.states):
            if var.parent in output_names:
                output_dep_vars.append((var, idx, output_names.index(var.parent)))
            elif var.parent in state_names:
                state_dep_vars.append((var, idx, state_names.index(var.parent)))
            else:
                raise ValueError(
                    f"Parent {var.parent!r} of state {var.name!r} is neither "
                    "an output nor a state variable"
                )
        return output_dep_vars, state_dep_vars

    def _resolve_variable_order(self, var_tuple_list):
        """Sort variable tuples so that parents are updated before children.

        Args:
            var_tuple_list: ``(state, state_index, parent_index)`` tuples as
                produced by ``_separate_variables``.

        Returns:
            A new list sorted by ``(priority, parent_depth)``: variables
            with relation ``'output'`` first, then by the length of the
            parent chain inside ``var_tuple_list``.

        Raises:
            ValueError: If the parent chain of a variable is cyclic.
        """
        var_map = {var.name: var for var, _, _ in var_tuple_list}

        def sort_key(sim_variable_tuple):
            var = sim_variable_tuple[0]
            # Output variables get highest priority, then resolve
            # parent-child dependencies.
            priority = 0 if var.relation == 'output' else 1
            parent_depth = 0
            visited = set()
            parent = var.parent
            while parent:
                parent_depth += 1
                if parent not in var_map:
                    # Chain leaves this group; nothing further to order.
                    break
                if parent in visited:
                    # Without this check a cycle would loop forever.
                    raise ValueError(
                        f"Cyclic dependency detected for state {var.name!r}"
                    )
                visited.add(parent)
                parent = var_map[parent].parent
            return (priority, parent_depth)

        return sorted(var_tuple_list, key=sort_key)

    def _step(self, x, output_next=None):
        """Do a single forward step of the model and update the state variables.

        Args:
            x: Trajectory window whose last time step is written in place.
                ``simulate`` passes a ``(batch, seq_length + 1, features)``
                slice; all but the last step are fed to ``self.forward``.
            output_next: Optional tensor that receives the raw model output
                for this step (written in place).
        """
        dx = self.forward(x[:, :-1, :])
        if output_next is not None:
            # Write in place: rebinding the name would leave the caller's
            # output buffer untouched.
            output_next[...] = dx

        # Update state variables depending on model output dx.
        for var, var_idx, parent_idx in self.output_dep_vars:
            if var.relation == 'output':
                x[:, -1, var_idx] = dx[:, parent_idx]
            elif var.relation == 'delta':
                x[:, -1, var_idx] = x[:, -2, var_idx] + dx[:, parent_idx]
            elif var.relation == 'derivative':
                x[:, -1, var_idx] = x[:, -2, var_idx] + dx[:, parent_idx] * self.dt

        # Update state variables depending on state x.
        # NOTE(review): relation 'output' has no branch here, so such a
        # state is left unchanged -- confirm this is intended.
        for var, var_idx, parent_idx in self.state_dep_vars:
            if var.relation == 'delta':
                x[:, -1, var_idx] = x[:, -2, var_idx] + x[:, -1, parent_idx]
            elif var.relation == 'derivative':
                x[:, -1, var_idx] = x[:, -2, var_idx] + x[:, -1, parent_idx] * self.dt

    def simulate(self, traj_solution, x0=None, u=None, output_traj=None):
        """Fill in the values of the state variables in ``traj_solution``.

        Runs under ``torch.no_grad()`` and the configured autocast context.
        The sequence length is read from ``self.config.sequence_length``,
        which this class does not set.

        Args:
            traj_solution: Trajectory buffer, modified in place, of shape
                (batch_size, sequence_length + sim_steps,
                num_states + num_control).
            x0: Optional initial state of shape
                (batch_size, sequence_length, num_states).
            u: Optional control inputs of shape
                (batch_size, sim_steps + sequence_length, num_control).
            output_traj: Optional buffer for the model outputs, filled in
                place, of shape (batch_size, sim_steps, num_outputs).
        """
        # Initialize simulation data.
        seq_length = self.config.sequence_length
        sim_steps = traj_solution.size(1) - seq_length

        if x0 is not None:
            traj_solution[:, :seq_length, :self.n_state] = x0
        # With no controls, "-0:" would select every column, so skip.
        if u is not None and self.n_control > 0:
            traj_solution[:, :, -self.n_control:] = u

        # A comma enters both contexts; "a and b" would enter only b.
        with self.amp_context, torch.no_grad():
            for i in range(sim_steps):
                # Pass in the trajectory up to the t+1 step.
                window = traj_solution[:, i:i + seq_length + 1, :]
                step_output = None if output_traj is None else output_traj[:, i, :]
                self._step(window, step_output)