rl8.models package
Module contents
Definitions related to parameterizations of policies.
Models are intended to be called with their respective forward pass (like any
other PyTorch module) to get the inputs to the policy’s action
distribution (along with other data depending on the type of model being used).
Models are expected to store their value function approximations after each forward pass in some intermediate attribute so they can be accessed with a subsequent call to a value_function method.
Models are largely inspired by RLlib’s model concept.
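The forward-then-value_function pattern described above can be sketched in plain Python. Everything here (TinyModel, its "logits" output, the toy value estimate) is an illustrative stand-in for the interface shape, not part of rl8's actual API:

```python
# Framework-free sketch: a model caches its value function output during the
# forward pass so a later value_function() call needs no extra forward pass.
# All names below are illustrative, not rl8's.


class TinyModel:
    def __init__(self) -> None:
        self._value = None  # value estimate from the most recent forward pass

    def forward(self, batch):
        obs = batch["obs"]
        # Cache the value function approximation (a toy mean here).
        self._value = sum(obs) / len(obs)
        # Return features for a (hypothetical) action distribution.
        return {"logits": [2.0 * o for o in obs]}

    def value_function(self):
        if self._value is None:
            raise RuntimeError("Call forward() before value_function().")
        return self._value


model = TinyModel()
features = model.forward({"obs": [1.0, 2.0, 3.0]})
value = model.value_function()  # no second forward pass needed
```

The point of the cached attribute is exactly what the docstrings below describe: when the value head and action head share parameters, one forward pass serves both outputs.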
- class rl8.models.DefaultContinuousModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: UnboundedContinuousTensorSpec, /, *, hiddens: Sequence[int] = (256, 256), activation_fn: str = 'relu', bias: bool = True)[source]
Bases: GenericModel[UnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec]
Default model for 1D continuous observations and action spaces.
- latent_model: Sequential
Transform observations to inputs for output heads.
- vf_model: Sequential
Value function model, independent of action params.
- forward(batch: TensorDict, /) TensorDict [source]
Process a batch of tensors and return features to be fed into an action distribution.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].
- Returns:
Features that will be passed to an action distribution with batch shape like [B, ...].
- to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- value_function() Tensor [source]
Return the value function output for the most recent forward pass. Note that GenericModelBase.forward() must be called before this method can return anything.
This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.
- class rl8.models.DefaultDiscreteModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: DiscreteTensorSpec, /, *, hiddens: Sequence[int] = (256, 256), activation_fn: str = 'relu', bias: bool = True)[source]
Bases: GenericModel[UnboundedContinuousTensorSpec, DiscreteTensorSpec]
Default model for 1D continuous observations and discrete action spaces.
- feature_model: Sequential
Transform observations to features for action distributions.
- vf_model: Sequential
Value function model, independent of action params.
- forward(batch: TensorDict, /) TensorDict [source]
Process a batch of tensors and return features to be fed into an action distribution.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].
- Returns:
Features that will be passed to an action distribution with batch shape like [B, ...].
- to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- value_function() Tensor [source]
Return the value function output for the most recent forward pass. Note that GenericModelBase.forward() must be called before this method can return anything.
This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.
- class rl8.models.GenericModel(observation_spec: _ObservationSpec, action_spec: _ActionSpec, /, **config: Any)[source]
Bases: Model, Generic[_ObservationSpec, _ActionSpec]
Generic model for constructing models from fixed observation and action specs.
- action_spec: _ActionSpec
Action space compatible with the model.
- observation_spec: _ObservationSpec
Observation space compatible with the model.
- class rl8.models.GenericModelBase(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]
Bases: Module[_P, _T]
The base for the policy component that processes environment observations into a value function approximation and features to be consumed by an action distribution for action sampling.
All model flavors (i.e., feedforward or recurrent) inherit from this base. This just provides a common interface for different model flavors.
- Parameters:
observation_spec – Spec defining the forward pass input.
action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.
config – Model-specific configuration.
- observation_spec: TensorSpec
Spec defining the forward pass input. Useful for validating the forward pass and for defining the model as a function of the observation spec.
- action_spec: TensorSpec
Spec defining the outputs of the policy’s action distribution that this model is a component of. Useful for defining the model as a function of the action spec.
- abstract to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- abstract value_function() Tensor [source]
Return the value function output for the most recent forward pass. Note that GenericModelBase.forward() must be called before this method can return anything.
This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.
- class rl8.models.Model(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]
Bases: GenericModelBase[(TensorDict,), TensorDict]
Feedforward policy component that processes environment observations into a value function approximation and features to be consumed by an action distribution for action sampling.
- Parameters:
observation_spec – Spec defining the forward pass input.
action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.
config – Model-specific configuration.
- view_requirements: dict[str, rl8.views.ViewRequirement]
Requirements on how a tensor batch should be preprocessed by the policy prior to being passed to the forward pass. Useful for handling sequence shifting or masking so you don’t have to. By default, observations are passed with no shifting. This should be overwritten in a model’s __init__ for custom view requirements.
- apply_view_requirements(batch: TensorDict, /, *, kind: Literal['last', 'all'] = 'last') TensorDict [source]
Apply the model’s view requirements, reshaping tensors as-needed.
This is usually called by the policy that the model is a component of, but can be used within the model if the model is deployed without the policy or action distribution.
- Parameters:
batch – Batch to feed into the policy’s underlying model. Expected to be of size [B, T, ...], where B is the batch dimension and T is the time or sequence dimension. B is typically the number of parallel environments being sampled during massively parallel training, and T is typically the number of time steps or observations sampled from the environments. The B and T dimensions are typically combined into one dimension during application of the view requirements.
kind – String indicating the type of view requirements to apply. The model’s view requirements are applied slightly differently depending on the value. Options include:
- "last": Apply the view requirements using only the samples necessary to sample for the most recent observations within the batch’s T dimension.
- "all": Sample from batch using all observations within the batch’s T dimension. Expand the B and T dimensions together.
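The difference between the two kinds can be illustrated with plain Python over nested lists (a rough sketch of the shapes involved, not rl8's actual implementation):

```python
# Illustrative only: a [B, T] batch as nested lists with B=2 environments
# and T=3 time steps per environment.
batch = [
    ["b0t0", "b0t1", "b0t2"],
    ["b1t0", "b1t1", "b1t2"],
]

# "all": expand the B and T dimensions together into one [B * T] batch.
all_view = [obs for env in batch for obs in env]

# "last": keep only the most recent observation per environment -> [B].
last_view = [env[-1] for env in batch]
```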
- static default_model_cls(observation_spec: TensorSpec, action_spec: TensorSpec, /) type['Model'] [source]
Return a default model class based on the given observation and action specs.
- Parameters:
observation_spec – Environment observation spec.
action_spec – Environment action spec.
- Returns:
A default model class.
- property drop_size: int
Return the model’s drop size (also the drop size for all view requirements).
- abstract forward(batch: TensorDict, /) TensorDict [source]
Process a batch of tensors and return features to be fed into an action distribution.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].
- Returns:
Features that will be passed to an action distribution with batch shape like [B, ...].
- to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- validate_view_requirements() None [source]
Helper for validating a model’s view requirements.
- Raises:
RuntimeError – If the model’s view requirements result in an ambiguous batch size, making training and sampling impossible.
- class rl8.models.ModelFactory(*args, **kwargs)[source]
Bases: Protocol
Factory protocol describing how to create a model instance.
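The factory idea can be sketched with typing.Protocol. The protocol and factory below are hypothetical stand-ins mirroring the spec-plus-config calling convention seen throughout this module, not rl8's actual definitions:

```python
from typing import Any, Protocol


class ModelFactoryLike(Protocol):
    """Hypothetical protocol: build a model from specs plus config."""

    def __call__(
        self, observation_spec: Any, action_spec: Any, /, **config: Any
    ) -> Any: ...


def make_model(observation_spec: Any, action_spec: Any, /, **config: Any) -> dict:
    # A real factory would return a Model instance; a dict stands in here.
    return {"obs": observation_spec, "act": action_spec, "config": config}


factory: ModelFactoryLike = make_model
model = factory("obs_spec", "act_spec", hiddens=(256, 256))
```

Because model classes themselves accept `(observation_spec, action_spec, /, **config)`, a class satisfying this calling convention can typically be passed wherever a factory is expected.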
- class rl8.models.DefaultContinuousRecurrentModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: UnboundedContinuousTensorSpec, /, *, hidden_size: int = 256, num_layers: int = 1, bias: bool = True)[source]
Bases: GenericRecurrentModel[UnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec]
Default recurrent model for 1D continuous observations and action spaces.
- forward(batch: TensorDict, states: TensorDict, /) tuple[TensorDict, TensorDict] [source]
Process a batch of tensors and return features to be fed into an action distribution.
Both input arguments are expected to have a 2D batch shape like [B, T, ...], where B is the batch number (typically the number of parallel environments) and T is the sequence length.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec.
states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.
- Returns:
Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...], while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimensions of the input arguments are flattened together for the output features, while the returned recurrent states keep the original batch dimension but don’t have a sequence dimension.
- to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- value_function() Tensor [source]
Return the value function output for the most recent forward pass. Note that GenericModelBase.forward() must be called before this method can return anything.
This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.
- class rl8.models.DefaultDiscreteRecurrentModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: DiscreteTensorSpec, /, *, hidden_size: int = 256, num_layers: int = 1, bias: bool = True)[source]
Bases: GenericRecurrentModel[UnboundedContinuousTensorSpec, DiscreteTensorSpec]
Default recurrent model for 1D continuous observations and discrete action spaces.
- forward(batch: TensorDict, states: TensorDict, /) tuple[TensorDict, TensorDict] [source]
Process a batch of tensors and return features to be fed into an action distribution.
Both input arguments are expected to have a 2D batch shape like [B, T, ...], where B is the batch number (typically the number of parallel environments) and T is the sequence length.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec.
states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.
- Returns:
Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...], while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimensions of the input arguments are flattened together for the output features, while the returned recurrent states keep the original batch dimension but don’t have a sequence dimension.
- to(device: str | device) Self [source]
Helper for changing the device the model is on.
The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.
- Parameters:
device – Target device.
- Returns:
The updated model.
- value_function() Tensor [source]
Return the value function output for the most recent forward pass. Note that GenericModelBase.forward() must be called before this method can return anything.
This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.
- class rl8.models.GenericRecurrentModel(observation_spec: _ObservationSpec, action_spec: _ActionSpec, /, **config: Any)[source]
Bases: RecurrentModel, Generic[_ObservationSpec, _ActionSpec]
Generic model for constructing models from fixed observation and action specs.
- action_spec: _ActionSpec
Action space compatible with the model.
- observation_spec: _ObservationSpec
Observation space compatible with the model.
- class rl8.models.RecurrentModel(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]
Bases: GenericModelBase[(TensorDict, TensorDict), tuple[TensorDict, TensorDict]]
Recurrent policy component that processes environment observations and recurrent model states into a value function approximation, features to be consumed by an action distribution for action sampling, and updated recurrent model states to be used for subsequent calls.
- Parameters:
observation_spec – Spec defining the forward pass input.
action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.
config – Model-specific configuration.
- state_spec: CompositeSpec
Spec defining the recurrent model states that are part of the forward pass input and output. This is expected to be defined in a model’s __init__.
- static default_model_cls(observation_spec: TensorSpec, action_spec: TensorSpec, /) type['RecurrentModel'] [source]
Return a default model class based on the given observation and action specs.
- Parameters:
observation_spec – Environment observation spec.
action_spec – Environment action spec.
- Returns:
A default model class.
- abstract forward(batch: TensorDict, states: TensorDict, /) tuple[TensorDict, TensorDict] [source]
Process a batch of tensors and return features to be fed into an action distribution.
Both input arguments are expected to have a 2D batch shape like [B, T, ...], where B is the batch number (typically the number of parallel environments) and T is the sequence length.
- Parameters:
batch – A tensordict expected to have at least an "obs" key with any tensor spec.
states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.
- Returns:
Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...], while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimensions of the input arguments are flattened together for the output features, while the returned recurrent states keep the original batch dimension but don’t have a sequence dimension.
- init_states(n: int, /) TensorDict [source]
Return initial recurrent states for the model.
Override this to make your own method for initializing recurrent states.
- Parameters:
n – Batch size to generate initial recurrent states for. This is typically the number of environments being stepped in parallel.
- Returns:
Recurrent model states that initialize a recurrent sequence.
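Putting init_states and forward together, a recurrent rollout roughly follows the shape below. This is a framework-free toy (the additive recurrence and all names are illustrative, not rl8's implementation); it only demonstrates the documented shapes: states start as [B, ...], forward consumes a [B, T, ...] batch, and it returns [B * T, ...] features plus updated [B, ...] states:

```python
# Illustrative sketch of the recurrent interface's shape contract.
class TinyRecurrentModel:
    def init_states(self, n):
        # One zeroed state per parallel environment -> [B, ...].
        return [0.0] * n

    def forward(self, batch, states):
        features = []
        new_states = []
        for env_obs, state in zip(batch, states):  # env_obs has length T
            for obs in env_obs:
                state = state + obs  # toy recurrence
                features.append(state)  # B and T flatten together -> [B * T]
            new_states.append(state)  # last state per environment -> [B]
        return features, new_states


model = TinyRecurrentModel()
states = model.init_states(2)  # B = 2 parallel environments
batch = [[1.0, 2.0], [3.0, 4.0]]  # B = 2, T = 2
features, states = model.forward(batch, states)  # features: [B * T], states: [B]
```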