rl8.models package

Module contents

Definitions related to parameterizations of policies.

Models are called through their forward pass (like any other PyTorch module) to produce the inputs to the policy's action distribution, along with other data that depends on the model type. After each forward pass, a model is expected to store its value function approximation in an intermediate attribute so it can be retrieved by a subsequent call to a value_function method.

Models are largely inspired by RLlib’s model concept.
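The forward/value-caching contract described above can be illustrated with a minimal, framework-free sketch. All names below are illustrative stand-ins (plain Python objects in place of tensors and tensordicts), not the rl8 API:

```python
# Minimal sketch of the model contract: forward() returns features for an
# action distribution and caches a value estimate; value_function() returns
# the cached estimate without another forward pass.
# All classes/names here are illustrative stand-ins, not the rl8 API.


class SketchModel:
    def __init__(self):
        self._value = None  # cached value function output

    def forward(self, batch):
        obs = batch["obs"]
        # Pretend "features" and "value" come from separate output heads.
        features = [o * 2.0 for o in obs]
        self._value = [sum(features) / len(features)]
        return {"features": features}

    def value_function(self):
        # Valid only after at least one forward() call.
        if self._value is None:
            raise RuntimeError("call forward() before value_function()")
        return self._value


model = SketchModel()
out = model.forward({"obs": [1.0, 2.0, 3.0]})
value = model.value_function()  # no extra forward pass needed
```

This mirrors the shared-parameter motivation above: the value estimate is computed once during the forward pass and cached, rather than recomputed on demand.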

class rl8.models.DefaultContinuousModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: UnboundedContinuousTensorSpec, /, *, hiddens: Sequence[int] = (256, 256), activation_fn: str = 'relu', bias: bool = True)[source]

Bases: GenericModel[UnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec]

Default model for 1D continuous observations and action spaces.

latent_model: Sequential

Transform observations to inputs for output heads.

action_mean: Linear

Output head for action mean for a normal distribution.

action_log_std: Linear

Output head for action log std for a normal distribution.

vf_model: Sequential

Value function model, independent of action params.
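How the action_mean and action_log_std heads parameterize a normal distribution can be sketched without any framework code. The function below is illustrative (rl8's actual sampling happens in its action distribution, not in the model):

```python
import math
import random

# Sketch: a mean head and a log-std head parameterize a normal distribution
# for action sampling. Predicting the log of the std keeps the std strictly
# positive after exponentiation. Names here are illustrative, not rl8 API.
def sample_action(mean: float, log_std: float) -> float:
    std = math.exp(log_std)
    return random.gauss(mean, std)

random.seed(0)
action = sample_action(mean=0.5, log_std=math.log(0.1))
```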

forward(batch: TensorDict, /) TensorDict[source]

Process a batch of tensors and return features to be fed into an action distribution.

Parameters:

batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].

Returns:

Features that will be passed to an action distribution with batch shape like [B, ...].

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

value_function() Tensor[source]

Return the value function output for the most recent forward pass. Note that a GenericModelBase.forward() call must be performed before this method can return anything.

This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.

class rl8.models.DefaultDiscreteModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: DiscreteTensorSpec, /, *, hiddens: Sequence[int] = (256, 256), activation_fn: str = 'relu', bias: bool = True)[source]

Bases: GenericModel[UnboundedContinuousTensorSpec, DiscreteTensorSpec]

Default model for 1D continuous observations and discrete action spaces.

feature_model: Sequential

Transform observations to features for action distributions.

vf_model: Sequential

Value function model, independent of action params.

forward(batch: TensorDict, /) TensorDict[source]

Process a batch of tensors and return features to be fed into an action distribution.

Parameters:

batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].

Returns:

Features that will be passed to an action distribution with batch shape like [B, ...].

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

value_function() Tensor[source]

Return the value function output for the most recent forward pass. Note that a GenericModelBase.forward() call must be performed before this method can return anything.

This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.

class rl8.models.GenericModel(observation_spec: _ObservationSpec, action_spec: _ActionSpec, /, **config: Any)[source]

Bases: Model, Generic[_ObservationSpec, _ActionSpec]

Generic model for constructing models from fixed observation and action specs.

action_spec: _ActionSpec

Action space compatible with the model.

observation_spec: _ObservationSpec

Observation space compatible with the model.

class rl8.models.GenericModelBase(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]

Bases: Module[_P, _T]

The base for the policy component that processes environment observations into a value function approximation and features to be consumed by an action distribution for action sampling.

All model flavors (i.e., feedforward or recurrent) inherit from this base. This just provides a common interface for different model flavors.

Parameters:
  • observation_spec – Spec defining the forward pass input.

  • action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.

  • config – Model-specific configuration.

observation_spec: TensorSpec

Spec defining the forward pass input. Useful for validating the forward pass and for defining the model as a function of the observation spec.

action_spec: TensorSpec

Spec defining the outputs of the policy’s action distribution that this model is a component of. Useful for defining the model as a function of the action spec.

config: dict[str, Any]

Model-specific configuration. Passed from the policy and algorithm.

property device: str | device

Return the device the model is currently on.

abstract to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

abstract value_function() Tensor[source]

Return the value function output for the most recent forward pass. Note that a GenericModelBase.forward() call must be performed before this method can return anything.

This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.

class rl8.models.Model(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]

Bases: GenericModelBase[(<class 'tensordict._td.TensorDict'>,), TensorDict]

Feedforward policy component that processes environment observations into a value function approximation and features to be consumed by an action distribution for action sampling.

Parameters:
  • observation_spec – Spec defining the forward pass input.

  • action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.

  • config – Model-specific configuration.

view_requirements: dict[str, rl8.views.ViewRequirement]

Requirements on how a tensor batch should be preprocessed by the policy prior to being passed to the forward pass. Useful for handling sequence shifting or masking so you don’t have to. By default, observations are passed with no shifting. This should be overwritten in a model’s __init__ for custom view requirements.

apply_view_requirements(batch: TensorDict, /, *, kind: Literal['last', 'all'] = 'last') TensorDict[source]

Apply the model’s view requirements, reshaping tensors as needed.

This is usually called by the policy that the model is a component of, but can be used within the model if the model is deployed without the policy or action distribution.

Parameters:
  • batch – Batch to feed into the policy’s underlying model. Expected to be of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments being sampled for during massively parallel training, and T is typically the number of time steps or observations sampled from the environments. The B and T dimensions are typically combined into one dimension during application of the view requirements.

  • kind

    String indicating the type of view requirements to apply. The model’s view requirements are applied slightly differently depending on the value. Options include:

    • "last": Apply the view requirements using only the samples necessary for the most recent observations within the batch’s T dimension.

    • "all": Sample from the batch using all observations within the batch’s T dimension, expanding the B and T dimensions together.
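The difference between the two kinds can be sketched over a [B, T] batch using nested lists as tensor stand-ins. This is illustrative only; rl8's real implementation operates on tensordicts via its view requirements:

```python
# Sketch of the "last" vs "all" behavior over a [B, T] batch
# (B=2 environments, T=3 time steps). Illustrative stand-in only.
batch = [  # shape [B=2, T=3]
    [10, 11, 12],
    [20, 21, 22],
]

# kind="last": keep only the samples needed for the most recent observations.
last = [row[-1] for row in batch]  # shape [B]

# kind="all": use every time step, flattening B and T together.
all_ = [x for row in batch for x in row]  # shape [B * T]
```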

static default_model_cls(observation_spec: TensorSpec, action_spec: TensorSpec, /) type['Model'][source]

Return a default model class based on the given observation and action specs.

Parameters:
  • observation_spec – Environment observation spec.

  • action_spec – Environment action spec.

Returns:

A default model class.
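The idea behind spec-based default-model dispatch can be sketched with dummy spec and model classes standing in for the torchrl specs and rl8 models. This mirrors the concept, not the exact implementation:

```python
# Sketch: choose a default model class from the action spec's type.
# All classes here are hypothetical stand-ins, not torchrl/rl8 types.
class ContinuousSpec: ...
class DiscreteSpec: ...
class ContinuousModel: ...
class DiscreteModel: ...

def default_model_cls(action_spec):
    # Discrete action specs get the discrete default; everything else
    # falls back to the continuous default.
    if isinstance(action_spec, DiscreteSpec):
        return DiscreteModel
    return ContinuousModel

cls = default_model_cls(DiscreteSpec())
```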

property drop_size: int

Return the model’s drop size (also the drop size for all view requirements).

abstract forward(batch: TensorDict, /) TensorDict[source]

Process a batch of tensors and return features to be fed into an action distribution.

Parameters:

batch – A tensordict expected to have at least an "obs" key with any tensor spec. The policy that the model is a component of processes the batch according to Model.view_requirements prior to passing the batch to the forward pass. The tensordict must have a 1D batch shape like [B, ...].

Returns:

Features that will be passed to an action distribution with batch shape like [B, ...].

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

validate_view_requirements() None[source]

Helper for validating a model’s view requirements.

Raises:

RuntimeError – If the model’s view requirements result in an ambiguous batch size, making training and sampling impossible.

class rl8.models.ModelFactory(*args, **kwargs)[source]

Bases: Protocol

Factory protocol describing how to create a model instance.
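A structural sketch of such a factory protocol, using typing.Protocol with hypothetical names (the real protocol's exact signature lives in rl8; this only illustrates the pattern of "any callable that builds a model from specs plus config satisfies the protocol"):

```python
from typing import Any, Protocol

# Sketch of a model-factory protocol: any callable taking observation and
# action specs (positional-only) plus config kwargs satisfies it.
# SketchModelFactory and make_model are hypothetical names.
class SketchModelFactory(Protocol):
    def __call__(
        self, observation_spec: Any, action_spec: Any, /, **config: Any
    ) -> Any: ...

def make_model(observation_spec: Any, action_spec: Any, /, **config: Any) -> dict:
    # A real factory would return a model instance; a dict stands in here.
    return {"obs": observation_spec, "act": action_spec, "config": config}

factory: SketchModelFactory = make_model
model = factory("obs_spec", "act_spec", hiddens=(256, 256))
```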

class rl8.models.DefaultContinuousRecurrentModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: UnboundedContinuousTensorSpec, /, *, hidden_size: int = 256, num_layers: int = 1, bias: bool = True)[source]

Bases: GenericRecurrentModel[UnboundedContinuousTensorSpec, UnboundedContinuousTensorSpec]

Default recurrent model for 1D continuous observations and action spaces.

lstm: LSTM

Transform observations to inputs for output heads.

action_mean: Linear

Output head for action mean for a normal distribution.

action_log_std: Linear

Output head for action log std for a normal distribution.

vf_model: Linear

Value function model, independent of action params.

forward(batch: TensorDict, states: TensorDict, /) tuple[tensordict._td.TensorDict, tensordict._td.TensorDict][source]

Process a batch of tensors and return features to be fed into an action distribution.

Both input arguments are expected to have a 2D batch shape like [B, T, ...] where B is the batch size (typically the number of parallel environments) and T is the sequence length.

Parameters:
  • batch – A tensordict expected to have at least an "obs" key with any tensor spec.

  • states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.

Returns:

Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...] while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimension of the input arguments are flattened together for the output features while the returned recurrent states maintain the original batch dimension but don’t have a sequence dimension.

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

value_function() Tensor[source]

Return the value function output for the most recent forward pass. Note that a GenericModelBase.forward() call must be performed before this method can return anything.

This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.

class rl8.models.DefaultDiscreteRecurrentModel(observation_spec: UnboundedContinuousTensorSpec, action_spec: DiscreteTensorSpec, /, *, hidden_size: int = 256, num_layers: int = 1, bias: bool = True)[source]

Bases: GenericRecurrentModel[UnboundedContinuousTensorSpec, DiscreteTensorSpec]

Default recurrent model for 1D continuous observations and discrete action spaces.

lstm: LSTM

Transform observations to inputs for output heads.

feature_head: Linear

Transform observations to features for action distributions.

vf_head: Linear

Value function model, independent of action params.

forward(batch: TensorDict, states: TensorDict, /) tuple[tensordict._td.TensorDict, tensordict._td.TensorDict][source]

Process a batch of tensors and return features to be fed into an action distribution.

Both input arguments are expected to have a 2D batch shape like [B, T, ...] where B is the batch size (typically the number of parallel environments) and T is the sequence length.

Parameters:
  • batch – A tensordict expected to have at least an "obs" key with any tensor spec.

  • states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.

Returns:

Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...] while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimension of the input arguments are flattened together for the output features while the returned recurrent states maintain the original batch dimension but don’t have a sequence dimension.

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

value_function() Tensor[source]

Return the value function output for the most recent forward pass. Note that a GenericModelBase.forward() call must be performed before this method can return anything.

This helps prevent extra forward passes from being performed just to get a value function output in case the value function and action distribution components share parameters.

class rl8.models.GenericRecurrentModel(observation_spec: _ObservationSpec, action_spec: _ActionSpec, /, **config: Any)[source]

Bases: RecurrentModel, Generic[_ObservationSpec, _ActionSpec]

Generic model for constructing models from fixed observation and action specs.

action_spec: _ActionSpec

Action space compatible with the model.

observation_spec: _ObservationSpec

Observation space compatible with the model.

class rl8.models.RecurrentModel(observation_spec: TensorSpec, action_spec: TensorSpec, /, **config: Any)[source]

Bases: GenericModelBase[(<class 'tensordict._td.TensorDict'>, <class 'tensordict._td.TensorDict'>), tuple[TensorDict, TensorDict]]

Recurrent policy component that processes environment observations and recurrent model states into a value function approximation, features to be consumed by an action distribution for action sampling, and updated recurrent model states to be used for subsequent calls.

Parameters:
  • observation_spec – Spec defining the forward pass input.

  • action_spec – Spec defining the outputs of the policy’s action distribution that this model is a component of.

  • config – Model-specific configuration.

state_spec: CompositeSpec

Spec defining the recurrent model states that are part of the forward pass input and output. This is expected to be defined in a model’s __init__.

static default_model_cls(observation_spec: TensorSpec, action_spec: TensorSpec, /) type['RecurrentModel'][source]

Return a default model class based on the given observation and action specs.

Parameters:
  • observation_spec – Environment observation spec.

  • action_spec – Environment action spec.

Returns:

A default model class.

abstract forward(batch: TensorDict, states: TensorDict, /) tuple[tensordict._td.TensorDict, tensordict._td.TensorDict][source]

Process a batch of tensors and return features to be fed into an action distribution.

Both input arguments are expected to have a 2D batch shape like [B, T, ...] where B is the batch size (typically the number of parallel environments) and T is the sequence length.

Parameters:
  • batch – A tensordict expected to have at least an "obs" key with any tensor spec.

  • states – A tensordict that contains the recurrent states for the model and has spec equal to RecurrentModel.state_spec.

Returns:

Features that will be passed to an action distribution and updated recurrent states. The features are expected to have batch shape like [B * T, ...] while the updated recurrent states are expected to have batch shape like [B, ...]. In other words, the batch and sequence dimension of the input arguments are flattened together for the output features while the returned recurrent states maintain the original batch dimension but don’t have a sequence dimension.
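The shape contract above (features flatten B and T together, states keep only B) can be sketched with nested lists standing in for tensors. This is illustrative only, not rl8 code:

```python
# Sketch of the recurrent forward shape contract: inputs are [B, T, ...];
# output features flatten to [B * T, ...] while updated recurrent states
# keep shape [B, ...] with no sequence dimension.
B, T = 2, 3
obs = [[(b, t) for t in range(T)] for b in range(B)]  # [B, T]
states = [0.0 for _ in range(B)]                      # [B]

features = [o for row in obs for o in row]            # [B * T]
new_states = [s + T for s in states]                  # still [B]
```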

init_states(n: int, /) TensorDict[source]

Return initial recurrent states for the model.

Override this to make your own method for initializing recurrent states.

Parameters:

n – Batch size to generate initial recurrent states for. This is typically the number of environments being stepped in parallel.

Returns:

Recurrent model states that initialize a recurrent sequence.
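A common choice for this method is zero-initialized states, one row per parallel environment. The sketch below uses plain dicts and lists with hypothetical keys; rl8's actual states are tensordicts shaped according to RecurrentModel.state_spec:

```python
# Sketch of init_states: zero-initialized recurrent states (e.g., LSTM
# hidden/cell states) for n parallel environments. Keys and shapes are
# illustrative, not rl8's state_spec.
def init_states(n: int, hidden_size: int = 4) -> dict:
    zeros = [[0.0] * hidden_size for _ in range(n)]
    return {"hidden": zeros, "cell": [row[:] for row in zeros]}

states = init_states(2)
```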

to(device: str | device) Self[source]

Helper for changing the device the model is on.

The specs associated with the model aren’t updated with the PyTorch module’s to method since they aren’t PyTorch modules themselves.

Parameters:

device – Target device.

Returns:

The updated model.

class rl8.models.RecurrentModelFactory(*args, **kwargs)[source]

Bases: Protocol

Factory protocol describing how to create a recurrent model instance.