rl8 package
Subpackages
- rl8.algorithms package
- Module contents
Algorithm
GenericAlgorithmBase
GenericAlgorithmBase.buffer
GenericAlgorithmBase.buffer_spec
GenericAlgorithmBase.entropy_scheduler
GenericAlgorithmBase.env
GenericAlgorithmBase.grad_scaler
GenericAlgorithmBase.hparams
GenericAlgorithmBase.lr_scheduler
GenericAlgorithmBase.optimizer
GenericAlgorithmBase.policy
GenericAlgorithmBase.state
GenericAlgorithmBase.collect()
GenericAlgorithmBase.horizons_per_env_reset
GenericAlgorithmBase.memory_stats()
GenericAlgorithmBase.params
GenericAlgorithmBase.step()
RecurrentAlgorithm
- Module contents
- rl8.models package
- rl8.nn package
- rl8.policies package
- rl8.trainers package
- Submodules
- rl8.trainers.config module
TrainConfig
TrainConfig.env_cls
TrainConfig.env_config
TrainConfig.model_cls
TrainConfig.model_config
TrainConfig.distribution_cls
TrainConfig.horizon
TrainConfig.horizons_per_env_reset
TrainConfig.num_envs
TrainConfig.seq_len
TrainConfig.seqs_per_state_reset
TrainConfig.optimizer_cls
TrainConfig.optimizer_config
TrainConfig.accumulate_grads
TrainConfig.enable_amp
TrainConfig.entropy_coeff
TrainConfig.gae_lambda
TrainConfig.gamma
TrainConfig.sgd_minibatch_size
TrainConfig.num_sgd_iters
TrainConfig.shuffle_minibatches
TrainConfig.clip_param
TrainConfig.vf_clip_param
TrainConfig.dual_clip_param
TrainConfig.vf_coeff
TrainConfig.target_kl_div
TrainConfig.max_grad_norm
TrainConfig.normalize_advantages
TrainConfig.normalize_rewards
TrainConfig.device
TrainConfig.recurrent
TrainConfig.build()
TrainConfig.from_file()
- Module contents
GenericTrainerBase
RecurrentTrainer
TrainConfig
TrainConfig.env_cls
TrainConfig.env_config
TrainConfig.model_cls
TrainConfig.model_config
TrainConfig.distribution_cls
TrainConfig.horizon
TrainConfig.horizons_per_env_reset
TrainConfig.num_envs
TrainConfig.seq_len
TrainConfig.seqs_per_state_reset
TrainConfig.optimizer_cls
TrainConfig.optimizer_config
TrainConfig.accumulate_grads
TrainConfig.enable_amp
TrainConfig.entropy_coeff
TrainConfig.gae_lambda
TrainConfig.gamma
TrainConfig.sgd_minibatch_size
TrainConfig.num_sgd_iters
TrainConfig.shuffle_minibatches
TrainConfig.clip_param
TrainConfig.vf_clip_param
TrainConfig.dual_clip_param
TrainConfig.vf_coeff
TrainConfig.target_kl_div
TrainConfig.max_grad_norm
TrainConfig.normalize_advantages
TrainConfig.normalize_rewards
TrainConfig.device
TrainConfig.recurrent
TrainConfig.build()
TrainConfig.from_file()
Trainer
Submodules
rl8.conditions module
Definitions for monitoring training metrics and determining whether metrics achieve some condition (most commonly useful for determining when to stop training).
- class rl8.conditions.Condition(*args, **kwargs)[source]
Bases:
Protocol
Condition callable that returns True if a condition is met. This is the interface used for early-stopping training.
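For illustration, a minimal sketch of a class satisfying this protocol, assuming conditions are called with a mapping of train stats (keyed like the Literal keys below) and return a bool; the exact call signature is defined by rl8.conditions.Condition.

```python
# Hypothetical condition satisfying the Condition protocol, assuming
# conditions are called with a mapping of train stats and return a
# bool indicating whether training should stop.
class ReturnsAboveThreshold:
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def __call__(self, train_stats: dict[str, float]) -> bool:
        # Met once the mean return reaches the threshold.
        return train_stats["returns/mean"] >= self.threshold
```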
- class rl8.conditions.And(conditions: list[rl8.conditions.Condition], /)[source]
Bases:
object
Convenience for joining results from multiple conditions with an AND.
- Parameters:
conditions – Conditions to join results for with an AND.
- conditions: list[rl8.conditions.Condition]
Conditions to join results for with an AND.
- class rl8.conditions.HitsLowerBound(key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms'], lower_bound: float, /)[source]
Bases:
object
Condition that returns True if the value being monitored hits a lower bound value.
- Parameters:
key – Key of train stat to monitor.
lower_bound – Minimum threshold for the value of key to reach before this condition returns True when called.
- key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms']
Key of train stat to inspect when called.
- class rl8.conditions.HitsUpperBound(key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms'], upper_bound: float, /)[source]
Bases:
object
Condition that returns True if the value being monitored hits an upper bound value.
- Parameters:
key – Key of train stat to monitor.
upper_bound – Maximum threshold for the value of key to reach before this condition returns True when called.
- key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms']
Key of train stat to inspect when called.
- class rl8.conditions.Plateaus(key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms'], /, *, patience: int = 5, rtol: float = 0.001)[source]
Bases:
object
Condition that returns True if the value being monitored plateaus for patience number of times.
- Parameters:
key – Key of train stat to monitor.
patience – Threshold for Plateaus.losses to reach for the condition to return True.
rtol – Relative tolerance when comparing values of Plateaus.key between calls to determine if the call contributes to Plateaus.losses.
- key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms']
Key of train stat to inspect when called.
- patience: int
Threshold for Plateaus.losses to reach for the condition to return True.
- rtol: float
Relative tolerance when comparing values of Plateaus.key between calls to determine if the call contributes to Plateaus.losses.
- losses: int
Number of times the value of Plateaus.key has been within Plateaus.rtol in a row. If this reaches Plateaus.patience, then the condition is met and this condition returns True.
- old_value: float
Last value of Plateaus.key.
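For illustration, a sketch (using the constructor signatures documented above) that combines Plateaus with HitsUpperBound via And so the joined condition is met only when mean returns plateau and a step budget is reached; how the condition is wired into a trainer is not shown here.

```python
from rl8.conditions import And, HitsUpperBound, Plateaus

# Met only when BOTH hold: mean returns have stayed within 0.1% of
# their previous value for 5 consecutive checks, and at least one
# million environment steps have been taken.
stop = And(
    [
        Plateaus("returns/mean", patience=5, rtol=1e-3),
        HitsUpperBound("env/steps", 1_000_000),
    ]
)
```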
- class rl8.conditions.StopsDecreasing(key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms'], /, *, patience: int = 5)[source]
Bases:
object
Condition that returns True if the value being monitored keeps the same minimum for patience number of times.
- Parameters:
key – Key of train stat to monitor.
patience – Threshold for StopsDecreasing.losses to reach for the condition to return True.
- key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms']
Key of train stat to inspect when called.
- patience: int
Threshold for StopsDecreasing.losses to reach for the condition to return True.
- losses: int
Number of times the value of StopsDecreasing.key has not passed StopsDecreasing.min_. If this reaches StopsDecreasing.patience, then the condition is met and this condition returns True.
- min_: float
Last value of StopsDecreasing.key.
- class rl8.conditions.StopsIncreasing(key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms'], /, *, patience: int = 5)[source]
Bases:
object
Condition that returns True if the value being monitored keeps the same maximum for patience number of times.
- Parameters:
key – Key of train stat to monitor.
patience – Threshold for StopsIncreasing.losses to reach for the condition to return True.
- key: Literal['algorithm/collects', 'algorithm/steps', 'env/resets', 'env/steps', 'profiling/collect_ms', 'returns/min', 'returns/max', 'returns/mean', 'returns/std', 'rewards/min', 'rewards/max', 'rewards/mean', 'rewards/std', 'coefficients/entropy', 'coefficients/vf', 'losses/entropy', 'losses/policy', 'losses/vf', 'losses/total', 'memory/free', 'memory/total', 'memory/percent', 'monitors/kl_div', 'profiling/step_ms']
Key of train stat to inspect when called.
- patience: int
Threshold for StopsIncreasing.losses to reach for the condition to return True.
- losses: int
Number of times the value of StopsIncreasing.key has not passed StopsIncreasing.max_. If this reaches StopsIncreasing.patience, then the condition is met and this condition returns True.
- max_: float
Last value of StopsIncreasing.key.
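Similarly, a one-line sketch using StopsIncreasing to flag when mean returns stop improving (the patience value is arbitrary):

```python
from rl8.conditions import StopsIncreasing

# Met once "returns/mean" fails to set a new maximum for 10
# consecutive checks of the train stats.
no_improvement = StopsIncreasing("returns/mean", patience=10)
```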
rl8.distributions module
- class rl8.distributions.Distribution(features: TensorDict, model: Any, /)[source]
Bases:
ABC
Policy component that defines a probability distribution over a feature set from a model.
This definition is largely inspired by RLlib’s action distribution. Most commonly, the feature set is a single vector of logits or log probabilities used for defining and sampling from the probability distribution. Custom probability distributions, however, are not constrained to just a single vector.
- Parameters:
features – Features from model’s forward pass.
model – Model for parameterizing the probability distribution.
- features: TensorDict
Features from Distribution.model’s forward pass. Simple action distributions expect one field and corresponding tensor in the tensordict, but custom action distributions can return any kind of tensordict from Distribution.model.
- model: Any
Model from the parent policy also passed to the action distribution. This is necessary in case the model has components that are only used for sampling or for computing probability distribution characteristics.
- static default_dist_cls(action_spec: TensorSpec, /) type['Distribution'] [source]
Return a default distribution given an action spec.
- Parameters:
action_spec – Spec defining required environment inputs.
- Returns:
A distribution for simple, supported action specs.
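As a hedged sketch, default_dist_cls maps a simple action spec to a distribution class; for a discrete spec this presumably resolves to Categorical and for an unbounded continuous spec to Normal. The spec classes and their import path come from torchrl and may vary by version.

```python
from torchrl.data import DiscreteTensorSpec, UnboundedContinuousTensorSpec

from rl8.distributions import Distribution

# Choose default distribution classes from simple action specs.
discrete_dist_cls = Distribution.default_dist_cls(DiscreteTensorSpec(4))
continuous_dist_cls = Distribution.default_dist_cls(
    UnboundedContinuousTensorSpec(3)
)
```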
- abstract deterministic_sample() Tensor | TensorDict [source]
Draw a deterministic sample from the probability distribution.
- abstract entropy() Tensor [source]
Compute the probability distribution’s entropy (a measurement of randomness).
- class rl8.distributions.TorchDistributionWrapper(features: TensorDict, model: Any, /)[source]
Bases:
Distribution, Generic[_FeatureSpec, _TorchDistribution, _ActionSpec]
Wrapper class for PyTorch distributions.
This is inspired by RLlib.
- dist: _TorchDistribution
Underlying PyTorch distribution.
- deterministic_sample() Tensor [source]
Draw a deterministic sample from the probability distribution.
- entropy() Tensor [source]
Compute the probability distribution’s entropy (a measurement of randomness).
- class rl8.distributions.Categorical(features: TensorDict, model: Any, /)[source]
Bases:
TorchDistributionWrapper[CompositeSpec, Categorical, DiscreteTensorSpec]
Wrapper around the PyTorch categorical (i.e., discrete) distribution.
- class rl8.distributions.Normal(features: TensorDict, model: Any)[source]
Bases:
TorchDistributionWrapper[CompositeSpec, Normal, UnboundedContinuousTensorSpec]
Wrapper around the PyTorch normal (i.e., Gaussian) distribution.
- class rl8.distributions.SquashedNormal(features: TensorDict, model: Any)[source]
Bases:
Normal
Squashed normal distribution such that samples are always within [-1, 1].
- deterministic_sample() Tensor [source]
Draw a deterministic sample from the probability distribution.
- entropy() Tensor [source]
Compute the probability distribution’s entropy (a measurement of randomness).
rl8.env module
Environment protocol definition and helper dummy environment definitions.
- class rl8.env.Env(num_envs: int, /, horizon: None | int = None, *, config: None | dict[str, Any] = None, device: str | device = 'cpu')[source]
Bases:
ABC
Protocol defining the IsaacGym-like environments for supporting highly parallelized simulation.
To define your own custom environment, you must define the following instance attributes:
Env.action_spec: The spec defining the environment’s inputs for its step function.
Env.observation_spec: The spec defining part of the environment’s outputs for its reset and step functions.
You must also define the following methods:
Env.reset(): Returns the initial observation.
Env.step(): Takes an action and returns the updated environment observation and the new environment reward.
- Parameters:
num_envs – Number of parallel and independent environments being simulated by one Env instance.
horizon – Number of steps the environment expects to take before being reset. None suggests the environment may never reset.
config – Config detailing simulation options/parameters for the environment’s initialization.
device – Device the environment’s underlying data should be initialized on.
- action_spec: TensorSpec
Spec defining the environment’s inputs (and policy’s action distribution’s outputs). Used for initializing the policy, the policy’s underlying components, and the learning buffer.
- config: dict[str, Any]
Environment config passed to the environment at instantiation. This could be overwritten by the environment’s reset, but it’s entirely at the developer’s discretion.
- horizon: None | int
The number of steps the environment expects to be taken before being reset. None suggests the environment may never be reset, but this convention is not consistent.
- max_horizon: ClassVar[int]
An optional attribute denoting the max number of steps an environment may take before being reset. Used to validate environment instantiation.
- max_num_envs: ClassVar[int]
An optional attribute denoting the max number of parallel environments an environment instance may hold at any given time. Used to validate environment instantiation.
- observation_spec: TensorSpec
Spec defining part of the environment’s outputs (and policy’s model’s outputs). Used for initializing the policy, the policy’s underlying components, and the learning buffer.
- abstract reset(*, config: None | dict[str, Any] = None) Tensor | TensorDict [source]
Reset the environment, applying a new environment config to it and returning a new, initial observation from the environment.
- Parameters:
config – Environment configuration/options/parameters.
- Returns:
Initial observation from the reset environment with spec Env.observation_spec.
- abstract step(action: Tensor | TensorDict) TensorDict [source]
Step the environment by applying an action, simulating an environment transition, and returning an observation and a reward.
- Parameters:
action – Action to apply to the environment with tensor spec Env.action_spec.
- Returns:
A tensordict containing “obs” and “rewards” keys and values.
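To make the protocol concrete, here is a minimal hypothetical environment sketch following the documented interface (specs come from torchrl, and observations/rewards are returned in a tensordict with "obs" and "rewards" keys); the base-class bookkeeping and spec constructors are assumptions.

```python
import torch
from tensordict import TensorDict
from torchrl.data import UnboundedContinuousTensorSpec

from rl8.env import Env


class RandomWalkEnv(Env):
    """Hypothetical environment: a 1D state drifts by the action."""

    def __init__(self, num_envs, /, horizon=None, *, config=None, device="cpu"):
        super().__init__(num_envs, horizon, config=config, device=device)
        # Kept locally for the sketch; the base class may already
        # track the number of environments and device.
        self.num_envs = num_envs
        self.device = device
        self.observation_spec = UnboundedContinuousTensorSpec(1, device=device)
        self.action_spec = UnboundedContinuousTensorSpec(1, device=device)

    def reset(self, *, config=None):
        self.state = torch.zeros(self.num_envs, 1, device=self.device)
        return self.state

    def step(self, action):
        self.state = self.state + action
        # Reward encourages keeping the state near the origin.
        return TensorDict(
            {"obs": self.state, "rewards": -self.state.abs()},
            batch_size=[self.num_envs],
            device=self.device,
        )
```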
- class rl8.env.EnvFactory(*args, **kwargs)[source]
Bases:
Protocol
Factory protocol describing how to create an environment instance.
- class rl8.env.GenericEnv(num_envs: int, /, horizon: None | int = None, *, config: None | dict[str, Any] = None, device: str | device = 'cpu')[source]
Bases:
Env, Generic[_ObservationSpec, _ActionSpec]
Generic version of Env for environments with constant specs.
- observation_spec: _ObservationSpec
Environment observation spec.
- action_spec: _ActionSpec
Environment action spec.
- class rl8.env.DummyEnv(num_envs: int, /, horizon: None | int = None, *, config: None | dict[str, Any] = None, device: str | device = 'cpu')[source]
Bases:
GenericEnv[UnboundedContinuousTensorSpec, _ActionSpec]
The simplest environment possible.
Useful for testing and debugging algorithms and policies. The state is just a position along a 1D axis and the action perturbs the state by some amount. The reward is the negative of the state’s distance from the origin, incentivizing policies to drive the state to the origin as quickly as possible.
The environment’s action space and step functions are defined by subclasses.
- bounds: float
State magnitude bounds for generating initial states upon environment creation and environment resets.
- reset(*, config: None | dict[str, Any] = None) Tensor [source]
Reset the environment, applying a new environment config to it and returning a new, initial observation from the environment.
- Parameters:
config – Environment configuration/options/parameters.
- Returns:
Initial observation from the reset environment with spec Env.observation_spec.
- class rl8.env.ContinuousDummyEnv(num_envs: int, /, horizon: None | int = None, *, config: None | dict[str, Any] = None, device: str | device = 'cpu')[source]
Bases:
DummyEnv[UnboundedContinuousTensorSpec]
A continuous version of the dummy environment.
Actions include moving the state left or right at any magnitude.
- step(action: Tensor) TensorDict [source]
Step the environment by applying an action, simulating an environment transition, and returning an observation and a reward.
- Parameters:
action – Action to apply to the environment with tensor spec Env.action_spec.
- Returns:
A tensordict containing “obs” and “rewards” keys and values.
- class rl8.env.DiscreteDummyEnv(num_envs: int, /, horizon: None | int = None, *, config: None | dict[str, Any] = None, device: str | device = 'cpu')[source]
Bases:
DummyEnv[DiscreteTensorSpec]
A discrete version of the dummy environment.
Actions include moving the state left or right one unit. This environment is considered more difficult to solve than its continuous counterpart because of the limited action space.
- step(action: Tensor) TensorDict [source]
Step the environment by applying an action, simulating an environment transition, and returning an observation and a reward.
- Parameters:
action – Action to apply to the environment with tensor spec Env.action_spec.
- Returns:
A tensordict containing “obs” and “rewards” keys and values.
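A small usage sketch of the dummy environments for smoke-testing; the action shape and number of discrete actions used here are assumptions, while the step() output keys follow the documentation above.

```python
import torch

from rl8.env import DiscreteDummyEnv

# 64 parallel dummy environments with a 32-step horizon.
env = DiscreteDummyEnv(64, horizon=32)
obs = env.reset()

# Apply a random left/right action to every parallel environment and
# read the resulting observation and reward tensors.
action = torch.randint(0, 2, (64, 1))
out = env.step(action)
obs, rewards = out["obs"], out["rewards"]
```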
rl8.schedulers module
Schedulers for scheduling values, learning rates, and entropy.
- class rl8.schedulers.Scheduler(*args, **kwargs)[source]
Bases:
Protocol
Scheduler protocol for returning a value according to environment sample count.
- class rl8.schedulers.ConstantScheduler(value: float, /)[source]
Bases:
object
Scheduler that outputs a constant value.
This is the default scheduler when a schedule type is not provided.
- Parameters:
value – Constant value to output.
- class rl8.schedulers.InterpScheduler(schedule: list[tuple[int, float]], /)[source]
Bases:
object
Scheduler that interpolates between steps to new values when the number of environment samples exceeds a threshold.
- Parameters:
schedule – List of tuples where the first element of the tuple is the number of environment transitions needed to trigger a step and the second element of the tuple is the value to step to when a step is triggered.
- class rl8.schedulers.StepScheduler(schedule: list[tuple[int, float]], /)[source]
Bases:
object
Scheduler that steps to a new value when the number of environment samples exceeds a threshold.
This is the default scheduler when a schedule is provided.
- Parameters:
schedule – List of tuples where the first element of the tuple is the number of environment transitions needed to trigger a step and the second element of the tuple is the value to step to when a step is triggered.
- class rl8.schedulers.EntropyScheduler(coeff: float, /, *, schedule: None | list[tuple[int, float]] = None, kind: Literal['interp', 'step'] = 'step')[source]
Bases:
object
Entropy scheduler for scheduling entropy coefficients based on environment transition counts during learning.
- Parameters:
coeff – Entropy coefficient value. This value is ignored if a schedule is provided.
schedule – Optional schedule that overrides coeff. This determines values of coeff according to the number of environment transitions experienced during learning.
kind – Kind of scheduler to use. Options include:
- “step”: jump to values and hold until a new environment transition count is reached.
- “interp”: jump to values like “step”, but interpolate between the current value and the next value.
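For illustration, an entropy schedule sketch that steps the coefficient down as environment transitions accumulate (the values are arbitrary; the schedule format is the (transition count, value) tuples described above):

```python
from rl8.schedulers import EntropyScheduler

# Step the entropy coefficient down at 1M and 5M environment
# transitions; kind="interp" would interpolate between values instead.
entropy_scheduler = EntropyScheduler(
    0.01,
    schedule=[(0, 0.01), (1_000_000, 0.005), (5_000_000, 0.0)],
    kind="step",
)
```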
- class rl8.schedulers.LRScheduler(optimizer: Optimizer, /, *, schedule: None | list[tuple[int, float]] = None, kind: Literal['interp', 'step'] = 'step')[source]
Bases:
object
Learning rate scheduler for scheduling optimizer learning rates based on environment transition counts during learning.
- Parameters:
optimizer – Optimizer to update with each LRScheduler.step().
schedule – Optional schedule that overrides the optimizer’s learning rate. This determines values of the learning rate according to the number of environment transitions experienced during learning.
kind – Kind of scheduler to use. Options include:
- “step”: jump to values and hold until a new environment transition count is reached.
- “interp”: jump to values like “step”, but interpolate between the current value and the next value.
- scheduler: Scheduler
Backend value scheduler used. The type depends on whether a schedule arg is provided and on kind’s value.
- coeff: float
Current learning rate according to the schedule. 0 until LRScheduler.step() is called for the first time.
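An analogous sketch for scheduling a learning rate (the optimizer and values are arbitrary):

```python
import torch

from rl8.schedulers import LRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=3e-4)

# Interpolate the learning rate from 3e-4 toward 1e-5 as environment
# transitions accumulate.
lr_scheduler = LRScheduler(
    optimizer,
    schedule=[(0, 3e-4), (10_000_000, 1e-5)],
    kind="interp",
)
```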
rl8.views module
Definitions regarding applying views to batches of tensors or tensor dicts.
- class rl8.views.View(*args, **kwargs)[source]
Bases:
Protocol
A view requirement protocol for processing batch elements during policy sampling and training.
Supports applying methods to a batch of size [B, T, ...] (where B is the batch dimension, and T is the time or sequence dimension) for all elements of B and T or just the last elements of T for all B.
- static apply_all(x: Tensor | TensorDict, size: int, /) Tensor | TensorDict [source]
Apply the view to all elements of B and T in a batch of size [B, T, ...] such that the returned batch is of shape [B_NEW, size, ...] where B_NEW <= B * T.
- rl8.views.pad_last_sequence(x: Tensor, size: int, /) TensorDict [source]
Pad the given tensor x along the time or sequence dimension such that the tensor’s time or sequence dimension is of size size when selecting the last size elements of the sequence.
- Parameters:
x – Tensor of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Minimum size of the sequence to select over x’s T dimension.
- Returns:
A tensordict with key "inputs" corresponding to the padded (or not padded) elements, and key "padding_mask" corresponding to booleans indicating which elements of "inputs" are padding.
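A shape-oriented sketch of pad_last_sequence based on the description above (the exact shapes are inferred from the documented behavior, not verified):

```python
import torch

from rl8.views import pad_last_sequence

# 4 parallel environments, 6 observed time steps, feature size 3.
x = torch.randn(4, 6, 3)

# Request the last 8 elements of each sequence; since only 6 exist,
# each sequence is padded up to length 8.
out = pad_last_sequence(x, 8)
inputs = out["inputs"]              # padded inputs, per the Returns above
padding_mask = out["padding_mask"]  # booleans flagging padded elements
```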
- rl8.views.pad_whole_sequence(x: Tensor, size: int, /) TensorDict [source]
Pad the given tensor x along the time or sequence dimension such that the tensor’s time or sequence dimension is of size size after applying rolling_window() to the tensor.
- Parameters:
x – Tensor of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Required sequence size for each batch element in x.
- Returns:
A tensordict with key "inputs" corresponding to the padded (or not padded) elements, and key "padding_mask" corresponding to booleans indicating which elements of "inputs" are padding.
- rl8.views.rolling_window(x: Tensor, size: int, /, *, step: int = 1) Tensor [source]
Unfold the given tensor x along the time or sequence dimension such that the tensor’s time or sequence dimension is mapped into two additional dimensions that represent a rolling window of size size and step step over the time or sequence dimension.
See PyTorch’s unfold for details on PyTorch’s vanilla unfolding that does most of the work.
- Parameters:
x – Tensor of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Size of the rolling window to create over x’s T dimension. The new sequence dimension is placed in the 2nd dimension.
step – Number of steps to take when iterating over x’s T dimension to create a new sequence of size size.
- Returns:
A new tensor of shape [B, (T - size) / step + 1, size, ...].
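A concrete shape check for rolling_window using the documented return shape:

```python
import torch

from rl8.views import rolling_window

# 2 parallel environments, 8 time steps, feature size 4.
x = torch.randn(2, 8, 4)

# Windows of size 4 with step 1 over the time dimension:
# [B, (T - size) / step + 1, size, ...] -> [2, 5, 4, 4].
windows = rolling_window(x, 4, step=1)
assert windows.shape == (2, 5, 4, 4)
```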
- class rl8.views.RollingWindow[source]
Bases:
object
A view that creates a rolling window of an item’s time or sequence dimension without masking (at the expense of losing some samples at the beginning of each sequence).
- static apply_all(x: Tensor | TensorDict, size: int, /) Tensor | TensorDict [source]
Unfold the given tensor or tensordict along the time or sequence dimension such that the time or sequence dimension becomes a rolling window of size size. The new time or sequence dimension is also expanded into the batch dimension such that each new sequence becomes an additional batch element.
The expanded batch dimension has sample loss because the initial size - 1 samples are required to make a sequence of size size.
- Parameters:
x – Tensor or tensordict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Size of the rolling window to create over x’s T dimension. The new sequence dimension is placed in the 2nd dimension.
- Returns:
A new tensor or tensordict of shape [B * (T - size + 1), size, ...].
- static apply_last(x: Tensor | TensorDict, size: int, /) Tensor | TensorDict [source]
Grab the last size elements of x along the time or sequence dimension.
- Parameters:
x – Tensor or tensordict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Number of “last” samples to grab along the time or sequence dimension T.
- Returns:
A new tensor or tensordict of shape [B, size, ...].
- class rl8.views.PaddedRollingWindow[source]
Bases:
object
A view that creates a rolling window of an item’s time or sequence dimension with padding and masking to make all batch elements the same size.
This is effectively the same as RollingWindow but with padding and masking applied beforehand.
- static apply_all(x: Tensor | TensorDict, size: int, /) TensorDict [source]
Unfold the given tensor or tensordict along the time or sequence dimension such that the time or sequence dimension becomes a rolling window of size size. The new time or sequence dimension is also expanded into the batch dimension such that each new sequence becomes an additional batch element.
The expanded batch dimension is always size B * T because this view pads and masks to force all sequence elements to be used.
- Parameters:
x – Tensor or tensordict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Size of the rolling window to create over x’s T dimension. The new sequence dimension is placed in the 2nd dimension.
- Returns:
A new tensor or tensordict of shape [B * T, size, ...].
- static apply_last(x: Tensor | TensorDict, size: int, /) TensorDict [source]
Grab the last size elements of x along the time or sequence dimension, and pad and mask to force the sequence to be of size size.
- Parameters:
x – Tensor or tensordict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
size – Number of “last” samples to grab along the time or sequence dimension T.
- Returns:
A new tensor or tensordict of shape [B, size, ...].
- class rl8.views.ViewRequirement(*, shift: int = 0, method: Literal['rolling_window', 'padded_rolling_window'] = 'padded_rolling_window')[source]
Bases:
object
Batch preprocessing for creating overlapping time series or sequential environment observations that’s applied prior to feeding samples into a policy’s model.
This component is purely for convenience. Its functionality can optionally be replicated within an environment’s observation function. However, because this functionality is fairly common, it’s recommended to use this component where simple time or sequence shifting is required for sequence-based observations.
- Parameters:
shift – Number of additional previous samples in the time or sequence dimension to include in the view requirement’s output.
method – Method for applying a nonzero shift view requirement. Options include:
- “rolling_window”: Create a rolling window over a tensor’s time or sequence dimension at the cost of dropping samples early into the sequence in order to force all sequences to be the same size.
- “padded_rolling_window”: The same as “rolling_window” but pad the beginning of each sequence to avoid dropping samples and provide a mask indicating which element is padding.
- shift: int
Number of additional previous samples in the time or sequence dimension to include in the view requirement’s output. E.g., if shift is 1, then the last two samples in the time or sequence dimension will be included for each batch element.
- method: type[rl8.views.View]
Method for applying a nonzero shift view requirement. Each method has its own advantage. Options include:
- "rolling_window": Create a rolling window over a tensor’s time or sequence dimension. This method is memory-efficient and fast, but it drops samples in order for each new batch element to have the same sequence size. Only use this method if the view requirement’s shift is much smaller than an environment’s horizon.
- "padded_rolling_window": The same as "rolling_window", but it pads the beginning of each sequence so no samples are dropped. This method also provides a padding mask for each tensor or tensor dict to indicate which sequence element is padding.
- apply_all(key: str | tuple[str, ...], batch: TensorDict, /) Tensor | TensorDict [source]
Apply the view to all of the time or sequence elements.
This method expands the elements of batch’s first two dimensions together to allow parallel batching of batch’s elements in the batch and time or sequence dimension together. This method is typically used within a training loop and isn’t typically used for sampling a policy’s actions or environment interaction.
- Parameters:
key – Key to apply the view requirement to for a given batch. The key can be any key that is compatible with a tensordict key. E.g., a key can be a tuple of strings such that the item in the batch is accessed like batch[("obs", "prices")].
batch – Tensor dict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
- Returns:
A tensor or tensordict of size [B_NEW, self.shift, ...] where B_NEW <= B * T, depending on the view requirement method applied. In the case where ViewRequirement.shift is 0, the returned tensor or tensordict has size [B * T, ...].
- apply_last(key: str | tuple[str, ...], batch: TensorDict, /) Tensor | TensorDict [source]
Apply the view to just the last time or sequence elements.
This method is typically used for sampling a model’s features and eventual sampling of a policy’s actions for parallel environments.
- Parameters:
key – Key to apply the view requirement to for a given batch. The key can be any key that is compatible with a tensordict key. E.g., a key can be a tuple of strings such that the item in the batch is accessed like batch[("obs", "prices")].
batch – Tensor dict of size [B, T, ...] where B is the batch dimension, and T is the time or sequence dimension. B is typically the number of parallel environments, and T is typically the number of time steps or observations sampled from each environment.
- Returns:
A tensor or tensordict of size [B, self.shift + 1, ...]. In the case where ViewRequirement.shift is 0, the returned tensor or tensordict has size [B, ...].
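To tie the views together, a hypothetical sketch of a ViewRequirement applied to a batch (the batch layout and key are illustrative; output sizes follow the Returns descriptions above):

```python
import torch
from tensordict import TensorDict

from rl8.views import ViewRequirement

B, T = 4, 16
batch = TensorDict({"obs": torch.randn(B, T, 3)}, batch_size=[B, T])

# Include the 4 previous samples (plus the current one) per element,
# padding the start of each sequence as needed.
view_requirement = ViewRequirement(shift=4, method="padded_rolling_window")

# For training: expand the batch and time dimensions together.
training_views = view_requirement.apply_all("obs", batch)

# For action sampling: just the trailing samples of each environment.
sampling_views = view_requirement.apply_last("obs", batch)
```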
Module contents
Top-level package interface.