"""Definitions regarding applying views to batches of tensors or tensor dicts."""
from typing import Literal, Protocol
import torch
from tensordict import TensorDict
from .data import DataKeys
ViewKind = Literal["last", "all"]
ViewMethod = Literal["rolling_window", "padded_rolling_window"]
class View(Protocol):
"""A view requirement protocol for processing batch elements during policy
sampling and training.
Supports applying methods to a batch of size ``[B, T, ...]`` (where ``B`` is the
batch dimension, and ``T`` is the time or sequence dimension) for all elements
of ``B`` and ``T`` or just the last elements of ``T`` for all ``B``.
"""
@staticmethod
def apply_all(
x: torch.Tensor | TensorDict, size: int, /
) -> torch.Tensor | TensorDict:
"""Apply the view to all elements of ``B`` and ``T`` in a batch of size
``[B, T, ...]`` such that the returned batch is of shape ``[B_NEW, size, ...]``
where ``B_NEW <= B * T``.
"""
@staticmethod
def apply_last(
x: torch.Tensor | TensorDict, size: int, /
) -> torch.Tensor | TensorDict:
"""Apply the view to the last elements of ``T`` for all ``B`` in a batch
of size ``[B, T, ...]`` such that the returned batch is of shape
``[B, size, ...]``.
"""
@staticmethod
def drop_size(size: int, /) -> int:
"""Return the amount of samples along the time or sequence dimension
that's dropped for each batch element.
This is used to determine batch size reshaping during training to make
batch components have the same size.
"""
def pad_last_sequence(x: torch.Tensor, size: int, /) -> TensorDict:
"""Pad the given tensor ``x`` along the time or sequence dimension such
that the tensor's time or sequence dimension is of size ``size`` when
selecting the last ``size`` elements of the sequence.
Args:
x: Tensor of size ``[B, T, ...]`` where ``B`` is the batch dimension, and
``T`` is the time or sequence dimension. ``B`` is typically the number
of parallel environments, and ``T`` is typically the number of time
steps or observations sampled from each environment.
size: Minimum size of the sequence to select over ``x``'s ``T`` dimension.
Returns:
A tensordict with key ``"inputs"`` corresponding to the padded (or not
padded) elements, and key ``"padding_mask"`` corresponding to booleans
indicating which elements of ``"inputs"`` are padding.
"""
B, T = x.shape[:2]
pad = size - T
if pad > 0:
F = x.shape[2:]
padding = torch.zeros(B, pad, *F, device=x.device, dtype=x.dtype)
x = torch.cat([padding, x], 1)
padding_mask = torch.zeros(B, size, device=x.device, dtype=torch.bool)
padding_mask[:, :pad] = True
else:
x = x[:, -size:, ...]
padding_mask = torch.zeros(B, size, device=x.device, dtype=torch.bool)
out = TensorDict({}, batch_size=[B, size], device=x.device)
out[DataKeys.INPUTS] = x
out[DataKeys.PADDING_MASK] = padding_mask
return out
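# Illustrative sketch (not from the original module) of ``pad_last_sequence``:
# with ``B = 2``, ``T = 3``, and ``size = 5``, two zero steps are prepended and
# flagged in the padding mask.
#
#     >>> x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
#     >>> out = pad_last_sequence(x, 5)
#     >>> out[DataKeys.INPUTS].shape
#     torch.Size([2, 5])
#     >>> out[DataKeys.PADDING_MASK][0]
#     tensor([ True,  True, False, False, False])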
def pad_whole_sequence(x: torch.Tensor, size: int, /) -> TensorDict:
"""Pad the given tensor ``x`` along the time or sequence dimension such
that the tensor's time or sequence dimension is of size ``size`` after
applying :meth:`rolling_window` to the tensor.
Args:
x: Tensor of size ``[B, T, ...]`` where ``B`` is the batch dimension, and
``T`` is the time or sequence dimension. ``B`` is typically the number
of parallel environments, and ``T`` is typically the number of time
steps or observations sampled from each environment.
size: Required sequence size for each batch element in ``x``.
Returns:
A tensordict with key ``"inputs"`` corresponding to the padded (or not
padded) elements, and key "padding_mask" corresponding to booleans
indicating which elements of ``"inputs"`` are padding.
"""
B, T = x.shape[:2]
F = x.shape[2:]
pad = RollingWindow.drop_size(size)
padding = torch.zeros(B, pad, *F, device=x.device, dtype=x.dtype)
x = torch.cat([padding, x], 1)
padding_mask = torch.zeros(B, T + pad, device=x.device, dtype=torch.bool)
padding_mask[:, :pad] = True
out = TensorDict({}, batch_size=[B, T + pad], device=x.device)
out[DataKeys.INPUTS] = x
out[DataKeys.PADDING_MASK] = padding_mask
return out
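# Illustrative sketch of ``pad_whole_sequence``: with ``T = 4`` and ``size = 3``,
# ``RollingWindow.drop_size(3) == 2`` zero steps are prepended so a subsequent
# size-3 rolling window keeps one window per original time step.
#
#     >>> x = torch.ones(2, 4)
#     >>> out = pad_whole_sequence(x, 3)
#     >>> out[DataKeys.INPUTS].shape
#     torch.Size([2, 6])
#     >>> out[DataKeys.PADDING_MASK][0]
#     tensor([ True,  True, False, False, False, False])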
def rolling_window(x: torch.Tensor, size: int, /, *, step: int = 1) -> torch.Tensor:
"""Unfold the given tensor ``x`` along the time or sequence dimension such
that the tensor's time or sequence dimension is mapped into two
additional dimensions that represent a rolling window of size ``size``
and step ``step`` over the time or sequence dimension.
See PyTorch's `unfold`_ for details on PyTorch's vanilla unfolding that does
most of the work.
Args:
x: Tensor of size ``[B, T, ...]`` where ``B`` is the batch dimension, and
``T`` is the time or sequence dimension. ``B`` is typically the number
of parallel environments, and ``T`` is typically the number of time
steps or observations sampled from each environment.
size: Size of the rolling window to create over ``x``'s ``T`` dimension.
The new sequence dimension is placed in the 2nd dimension.
step: Number of steps to take when iterating over ``x``'s ``T`` dimension
to create a new sequence of size ``size``.
Returns:
A new tensor of shape ``[B, (T - size) / step + 1, size, ...]``.
.. _`unfold`: https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
"""
dims = [i for i in range(x.dim())]
dims.insert(2, -1)
return x.unfold(1, size, step).permute(*dims)
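# Illustrative sketch of ``rolling_window``: a ``[1, 5]`` tensor with
# ``size = 3`` and ``step = 1`` unfolds into ``(5 - 3) / 1 + 1 = 3`` overlapping
# windows placed in the 2nd dimension.
#
#     >>> x = torch.arange(5).reshape(1, 5)
#     >>> rolling_window(x, 3)
#     tensor([[[0, 1, 2],
#              [1, 2, 3],
#              [2, 3, 4]]])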
class RollingWindow:
"""A view that creates a rolling window of an item's time or sequence
dimension without masking (at the expense of losing some samples at
the beginning of each sequence).
"""
@staticmethod
def apply_all(
x: torch.Tensor | TensorDict, size: int, /
) -> torch.Tensor | TensorDict:
"""Unfold the given tensor or tensordict along the time or sequence
dimension such that the the time or sequence dimension becomes a
rolling window of size ``size``. The new time or sequence dimension is
also expanded into the batch dimension such that each new sequence
becomes an additional batch element.
The expanded batch dimension loses samples because the first
``size - 1`` samples of each sequence are required to form the first
full window of size ``size``.
Args:
x: Tensor or tensordict of size ``[B, T, ...]`` where ``B`` is the
batch dimension, and ``T`` is the time or sequence dimension.
``B`` is typically the number of parallel environments, and ``T``
is typically the number of time steps or observations sampled
from each environment.
size: Size of the rolling window to create over ``x``'s ``T`` dimension.
The new sequence dimension is placed in the 2nd dimension.
Returns:
A new tensor or tensordict of shape
``[B * (T - size + 1), size, ...]``.
"""
if isinstance(x, torch.Tensor):
E = x.shape[2:]
return rolling_window(x, size, step=1).reshape(-1, size, *E)
else:
B_OLD, T_OLD = x.shape[:2]
T_NEW = T_OLD - size + 1
return x.apply(
lambda x: rolling_window(x, size, step=1), batch_size=[B_OLD, T_NEW]
).reshape(-1)
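# Illustrative sketch of ``RollingWindow.apply_all``: a ``[2, 5]`` batch with
# ``size = 3`` produces ``2 * (5 - 3 + 1) = 6`` windows of length 3.
#
#     >>> x = torch.arange(10).reshape(2, 5)
#     >>> RollingWindow.apply_all(x, 3).shape
#     torch.Size([6, 3])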
@staticmethod
def apply_last(
x: torch.Tensor | TensorDict, size: int, /
) -> torch.Tensor | TensorDict:
"""Grab the last ``size`` elements of ``x`` along the time or sequence
dimension.
Args:
x: Tensor or tensordict of size ``[B, T, ...]`` where ``B`` is the
batch dimension, and ``T`` is the time or sequence dimension.
``B`` is typically the number of parallel environments, and ``T``
is typically the number of time steps or observations sampled
from each environment.
size: Number of "last" samples to grab along the time or sequence
dimension ``T``.
Returns:
A new tensor or tensordict of shape ``[B, size, ...]``.
"""
if isinstance(x, torch.Tensor):
return x[:, -size:, ...]
else:
B, T = x.shape[:2]
T_NEW = min(T, size)
return x.apply(lambda x: x[:, -size:, ...], batch_size=[B, T_NEW])
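# Illustrative sketch of ``RollingWindow.apply_last``: only the trailing
# ``size`` steps of each batch element are kept.
#
#     >>> x = torch.arange(10).reshape(2, 5)
#     >>> RollingWindow.apply_last(x, 2)
#     tensor([[3, 4],
#             [8, 9]])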
@staticmethod
def drop_size(size: int, /) -> int:
"""This view doesn't perform any padding or masking and instead
drops a small amount of samples at the beginning of each
sequence in order to create sequences of the same length.
"""
return size - 1
class PaddedRollingWindow:
"""A view that creates a rolling window of an item's time or sequence
dimension with padding and masking to make all batch elements the same
size.
This is effectively the same as :class:`RollingWindow` but with padding and
masking applied beforehand.
"""
@staticmethod
def apply_all(x: torch.Tensor | TensorDict, size: int, /) -> TensorDict:
"""Unfold the given tensor or tensordict along the time or sequence
dimension such that the the time or sequence dimension becomes a
rolling window of size ``size``. The new time or sequence dimension is
also expanded into the batch dimension such that each new sequence
becomes an additional batch element.
The expanded batch dimension is always of size ``B * T`` because this
view pads and masks so that all sequence elements are used.
Args:
x: Tensor or tensordict of size ``[B, T, ...]`` where ``B`` is the
batch dimension, and ``T`` is the time or sequence dimension.
``B`` is typically the number of parallel environments, and ``T``
is typically the number of time steps or observations sampled
from each environment.
size: Size of the rolling window to create over ``x``'s ``T`` dimension.
The new sequence dimension is placed in the 2nd dimension.
Returns:
A new tensor or tensordict of shape
``[B * T, size, ...]``.
"""
if isinstance(x, torch.Tensor):
return RollingWindow.apply_all(pad_whole_sequence(x, size), size)
else:
B_OLD, T_OLD = x.shape[:2]
T_NEW = T_OLD + RollingWindow.drop_size(size)
return RollingWindow.apply_all(
x.apply(
lambda x: pad_whole_sequence(x, size), batch_size=[B_OLD, T_NEW]
),
size,
)
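# Illustrative sketch of ``PaddedRollingWindow.apply_all``: because the
# sequence is left-padded first, a ``[2, 4]`` batch with ``size = 3`` keeps all
# ``2 * 4 = 8`` windows, and the result carries a matching padding mask.
#
#     >>> x = torch.ones(2, 4)
#     >>> out = PaddedRollingWindow.apply_all(x, 3)
#     >>> out[DataKeys.INPUTS].shape
#     torch.Size([8, 3])
#     >>> out[DataKeys.PADDING_MASK].shape
#     torch.Size([8, 3])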
@staticmethod
def apply_last(x: torch.Tensor | TensorDict, size: int, /) -> TensorDict:
"""Grab the last ``size`` elements of ``x`` along the time or sequence
dimension, and pad and mask to force the sequence to be of size ``size``.
Args:
x: Tensor or tensordict of size ``[B, T, ...]`` where ``B`` is the
batch dimension, and ``T`` is the time or sequence dimension.
``B`` is typically the number of parallel environments, and ``T``
is typically the number of time steps or observations sampled
from each environment.
size: Number of "last" samples to grab along the time or sequence
dimension ``T``.
Returns:
A new tensor or tensordict of shape ``[B, size, ...]``.
"""
if isinstance(x, torch.Tensor):
return pad_last_sequence(x, size)
else:
B = x.size(0)
return x.apply(lambda x: pad_last_sequence(x, size), batch_size=[B, size])
@staticmethod
def drop_size(size: int, /) -> int:
"""This view pads the beginning of each sequence and provides masking
to avoid dropping-off samples.
"""
return 0
class ViewRequirement:
"""Batch preprocessing for creating overlapping time series or sequential
environment observations that's applied prior to feeding samples into a
policy's model.
This component is purely for convenience. Its functionality can optionally
be replicated within an environment's observation function. However, because
this functionality is fairly common, it's recommended to use this
component where simple time or sequence shifting is required for
sequence-based observations.
Args:
shift: Number of additional previous samples in the time or sequence
dimension to include in the view requirement's output.
method: Method for applying a nonzero shift view requirement.
Options include:
- "rolling_window": Create a rolling window over a tensor's
time or sequence dimension at the cost of dropping
samples early into the sequence in order to force all
sequences to be the same size.
- "padded_rolling_window": The same as "rolling_window" but
pad the beginning of each sequence to avoid dropping
samples and provide a mask indicating which element is
padding.
"""
#: Method for applying a nonzero shift view requirement. Each method
#: has its own advantage. Options include:
#:
#: - ``"rolling_window"``: Create a rolling window over a tensor's
#: time or sequence dimension. This method is memory-efficient
#: and fast, but it drops samples in order for each new batch
#: element to have the same sequence size. Only use this method
#: if the view requirement's shift is much smaller than an
#: environment's horizon.
#: - ``"padded_rolling_window"``: The same as ``"rolling_window"``, but it
#: pads the beginning of each sequence so no samples are dropped. This
#: method also provides a padding mask for each tensor or tensor
#: dict to indicate which sequence element is padding.
method: type[View]
#: Number of additional previous samples in the time or sequence dimension
#: to include in the view requirement's output. E.g., if shift is ``1``,
#: then the last two samples in the time or sequence dimension will be
#: included for each batch element.
shift: int
def __init__(
self,
*,
shift: int = 0,
method: ViewMethod = "padded_rolling_window",
) -> None:
self.shift = shift
if shift < 0:
raise ValueError(f"{self.__class__.__name__} `shift` must be non-negative.")
match method:
case "rolling_window":
self.method = RollingWindow
case "padded_rolling_window":
self.method = PaddedRollingWindow
case _:
raise ValueError(
f"{self.__class__.__name__} `method` must be one of"
' "rolling_window" or "padded_rolling_window".'
)
def apply_all(
self, key: str | tuple[str, ...], batch: TensorDict, /
) -> torch.Tensor | TensorDict:
"""Apply the view to all of the time or sequence elements.
This method flattens ``batch``'s first two dimensions together so
that elements across both the batch and the time or sequence
dimension can be processed in parallel. This method is typically
used within a training loop and isn't typically used for sampling
a policy's actions or for environment interaction.
Args:
key: Key to apply the view requirement to for a given batch. The key
can be any key that is compatible with a tensordict key.
E.g., a key can be a tuple of strings such that the item in the
batch is accessed like ``batch[("obs", "prices")]``.
batch: Tensor dict of size ``[B, T, ...]`` where ``B`` is the batch
dimension, and ``T`` is the time or sequence dimension. ``B`` is
typically the number of parallel environments, and ``T`` is
typically the number of time steps or observations sampled
from each environment.
Returns:
A tensor or tensordict of size ``[B_NEW, self.shift + 1, ...]``
where ``B_NEW <= B * T``, depending on the view requirement method
applied. In the case where :attr:`ViewRequirement.shift` is ``0``,
the returned tensor or tensordict has size ``[B * T, ...]``.
"""
item = batch[key]
with torch.no_grad():
if not self.shift:
if isinstance(item, torch.Tensor):
return item.flatten(end_dim=1)
else:
return item.reshape(-1)
return self.method.apply_all(item, self.shift + 1)
def apply_last(
self, key: str | tuple[str, ...], batch: TensorDict, /
) -> torch.Tensor | TensorDict:
"""Apply the view to just the last time or sequence elements.
This method is typically used for sampling a model's features
and eventual sampling of a policy's actions for parallel environments.
Args:
key: Key to apply the view requirement to for a given batch. The key
can be any key that is compatible with a tensordict key.
E.g., a key can be a tuple of strings such that the item in the
batch is accessed like ``batch[("obs", "prices")]``.
batch: Tensor dict of size ``[B, T, ...]`` where ``B`` is the batch
dimension, and ``T`` is the time or sequence dimension. ``B`` is
typically the number of parallel environments, and ``T`` is
typically the number of time steps or observations sampled
from each environment.
Returns:
A tensor or tensordict of size ``[B, self.shift + 1, ...]``. In the
case where :attr:`ViewRequirement.shift` is ``0``, the returned
tensor or tensordict has size ``[B, ...]``.
"""
item = batch[key]
with torch.no_grad():
if not self.shift:
return item[:, -1, ...]
return self.method.apply_last(item, self.shift + 1)
@property
def drop_size(self) -> int:
"""Return the number of samples dropped when using the underlying view requirement method.
"""
return self.method.drop_size(self.shift + 1)
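# End-to-end illustrative sketch of ``ViewRequirement`` (the ``"obs"`` key
# below is purely illustrative): with ``B = 2`` parallel environments,
# ``T = 4`` observations each, and a shift of 2, the training-time view
# flattens the batch into 8 padded windows of length 3, while the
# sampling-time view keeps one window per environment.
#
#     >>> batch = TensorDict({"obs": torch.ones(2, 4, 8)}, batch_size=[2, 4])
#     >>> view = ViewRequirement(shift=2, method="padded_rolling_window")
#     >>> view.apply_all("obs", batch)[DataKeys.INPUTS].shape
#     torch.Size([8, 3, 8])
#     >>> view.apply_last("obs", batch)[DataKeys.INPUTS].shape
#     torch.Size([2, 3, 8])
#     >>> view.drop_size
#     0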