Source code for rl8.models._base

from abc import ABCMeta, abstractmethod
from typing import Any, ParamSpec, TypeVar

import torch
from torchrl.data import TensorSpec
from typing_extensions import Self

from ..data import Device
from ..nn import Module

# Type parameters for the generic ``Module`` base that ``GenericModelBase``
# subclasses as ``Module[_P, _T]``. ``_P`` is a ParamSpec (it captures a
# callable's parameter list) and ``_T`` is a plain TypeVar; presumably they
# describe ``Module``'s call arguments and return type respectively -- confirm
# against ``..nn.Module``.
_P = ParamSpec("_P")
_T = TypeVar("_T")


class GenericModelBase(Module[_P, _T], metaclass=ABCMeta):
    """The base for the policy component that processes environment
    observations into a value function approximation and features to be
    consumed by an action distribution for action sampling.

    All model flavors (i.e., feedforward or recurrent) inherit from this
    base. This just provides a common interface for different model flavors.

    Args:
        observation_spec: Spec defining the forward pass input.
            Positional-only.
        action_spec: Spec defining the outputs of the policy's action
            distribution that this model is a component of. Positional-only.
        config: Model-specific configuration, collected from any extra
            keyword arguments.

    """

    #: Spec defining the outputs of the policy's action distribution that
    #: this model is a component of. Useful for defining the model as a
    #: function of the action spec.
    action_spec: TensorSpec

    #: Model-specific configuration. Passed from the policy and algorithm.
    config: dict[str, Any]

    #: Spec defining the forward pass input. Useful for validating the forward
    #: pass and for defining the model as a function of the observation spec.
    observation_spec: TensorSpec

    def __init__(
        self,
        observation_spec: TensorSpec,
        action_spec: TensorSpec,
        /,
        **config: Any,
    ) -> None:
        # Set up the PyTorch module state before assigning any attributes.
        super().__init__()
        self.observation_spec = observation_spec
        self.action_spec = action_spec
        self.config = config

    @property
    def device(self) -> Device:
        """Return the device the model is currently on.

        The device is read from the model's first parameter. Models without
        parameters fall back to their first buffer.

        Raises:
            RuntimeError: If the model has neither parameters nor buffers,
                so there is no tensor to infer a device from.

        """
        # A bare ``next(self.parameters())`` leaks an opaque ``StopIteration``
        # out of this property for parameter-free models; use defaults instead.
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is None:
            # Deliberately not ``AttributeError``: PyTorch's
            # ``nn.Module.__getattr__`` fallback would mask it behind a
            # misleading "has no attribute 'device'" error.
            raise RuntimeError(
                f"Unable to infer the device of {type(self).__name__}: "
                "the model has no parameters or buffers."
            )
        return tensor.device

    @abstractmethod
    def to(self, device: Device) -> Self:  # type: ignore[override]
        """Helper for changing the device the model is on.

        The specs associated with the model aren't updated with the PyTorch
        module's ``to`` method since they aren't PyTorch modules themselves.

        Args:
            device: Target device.

        Returns:
            The updated model.

        """

    @abstractmethod
    def value_function(self) -> torch.Tensor:
        """Return the value function output for the most recent forward pass.

        Note that a :meth:`GenericModelBase.forward` call has to be performed
        first before this method can return anything.

        This helps prevent extra forward passes from being performed just to
        get a value function output in case the value function and action
        distribution components share parameters.

        """