Source code for rl8.env

"""Environment protocol definition and helper dummy environment definitions."""

from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, Protocol, TypeVar

import torch
from tensordict import TensorDict
from torchrl.data import Categorical, TensorSpec, Unbounded

from .data import DataKeys, Device

_ObservationSpec = TypeVar("_ObservationSpec", bound=TensorSpec)
_ActionSpec = TypeVar("_ActionSpec", bound=TensorSpec)


class Env(ABC):
    """Protocol defining IsaacGym-like environments that support highly
    parallelized simulation.

    To define your own custom environment, you must define the following
    instance attributes:

    - :attr:`Env.action_spec`: The spec defining the environment's inputs
      for its step function.
    - :attr:`Env.observation_spec`: The spec defining part of the
      environment's outputs for its reset and step functions.

    You must also define the following methods:

    - :meth:`Env.reset`: Returns the initial observation.
    - :meth:`Env.step`: Takes an action and returns the updated environment
      observation and the new environment reward.

    Args:
        num_envs: Number of parallel and independent environments being
            simulated by one :class:`Env` instance.
        horizon: Number of steps the environment expects to take before
            being reset. ``None`` suggests the environment may never reset.
        device: Device the environment's underlying data should be
            initialized on.

    """

    #: Spec defining the environment's inputs (and the policy's action
    #: distribution's outputs). Used for initializing the policy, the
    #: policy's underlying components, and the learning buffer.
    action_spec: TensorSpec

    #: Device the environment's states, observations, and rewards reside
    #: on.
    device: Device

    #: Number of steps the environment expects to take before being
    #: reset. ``None`` suggests the environment may never be reset, but
    #: this convention is not strictly enforced.
    horizon: None | int

    #: Optional attribute denoting the max number of steps an environment
    #: may take before being reset. Used to validate environment
    #: instantiation.
    max_horizon: ClassVar[int]

    #: Optional attribute denoting the max number of parallel environments
    #: an environment instance may hold at any given time. Used to validate
    #: environment instantiation.
    max_num_envs: ClassVar[int]

    #: Number of parallel and independent environments being simulated.
    num_envs: int

    #: Spec defining part of the environment's outputs (and the policy
    #: model's outputs). Used for initializing the policy, the policy's
    #: underlying components, and the learning buffer.
    observation_spec: TensorSpec

    def __init__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> None:
        if hasattr(self, "max_horizon") and horizon is not None:
            if not (horizon <= self.max_horizon):
                raise ValueError(
                    f"{self.__class__.__name__} `horizon` must be <="
                    f" {self.max_horizon}."
                )
        if hasattr(self, "max_num_envs"):
            if not (num_envs <= self.max_num_envs):
                raise ValueError(
                    f"{self.__class__.__name__} `num_envs` must be <="
                    f" {self.max_num_envs}."
                )
        self.num_envs = num_envs
        self.horizon = horizon
        self.device = device
    @abstractmethod
    def reset(
        self, *, config: None | dict[str, Any] = None
    ) -> torch.Tensor | TensorDict:
        """Reset the environment, applying a new environment config to it
        and returning a new, initial observation from the environment.

        Args:
            config: Environment configuration/options/parameters.

        Returns:
            Initial observation from the reset environment with spec
            :attr:`Env.observation_spec`.

        """
    @abstractmethod
    def step(self, action: torch.Tensor | TensorDict) -> TensorDict:
        """Step the environment by applying an action, simulating an
        environment transition, and returning an observation and a reward.

        Args:
            action: Action to apply to the environment with tensor spec
                :attr:`Env.action_spec`.

        Returns:
            A tensordict containing "obs" and "rewards" keys and values.

        """
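# Illustrative sketch (not part of rl8 itself): a minimal custom environment
# satisfying the ``Env`` protocol. The ``ExampleTargetEnv`` name and its
# dynamics are hypothetical; the spec constructors and data keys mirror those
# used by the dummy environments below.
class ExampleTargetEnv(Env):
    """Drive a 1D state toward a fixed target with a continuous action."""

    def __init__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> None:
        super().__init__(num_envs, horizon, device=device)
        self.observation_spec = Unbounded(1, device=self.device)
        self.action_spec = Unbounded(shape=torch.Size([1]), device=self.device)
        self.target = 1.0

    def reset(self, *, config: None | dict[str, Any] = None) -> torch.Tensor:
        config = config or {}
        # Hypothetical config key: move the goal between episodes.
        self.target = config.get("target", self.target)
        self.state = torch.zeros(self.num_envs, 1, device=self.device)
        return self.state

    def step(self, action: torch.Tensor) -> TensorDict:
        self.state += action
        return TensorDict(
            {
                DataKeys.OBS: self.state,
                DataKeys.REWARDS: -(self.state - self.target).abs(),
            },
            batch_size=self.num_envs,
            device=self.device,
        )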
class EnvFactory(Protocol):
    """Factory protocol describing how to create an environment instance."""

    #: Optional attribute denoting the max number of steps an environment
    #: may take before being reset. Used to validate environment
    #: instantiation.
    max_horizon: ClassVar[int]

    #: Optional attribute denoting the max number of parallel environments
    #: an environment instance may hold at any given time. Used to validate
    #: environment instantiation.
    max_num_envs: ClassVar[int]

    def __call__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> Env:
        ...
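# Usage note (a sketch, not part of the module): because ``__call__`` here
# matches ``Env.__init__``, an ``Env`` subclass itself can loosely serve as an
# ``EnvFactory``; calling the class constructs the instance. ``make_env`` is a
# hypothetical helper name.
def make_env(
    env_factory: EnvFactory, num_envs: int, *, device: Device = "cpu"
) -> Env:
    # Validation against ``max_horizon``/``max_num_envs`` (when defined)
    # happens inside ``Env.__init__``.
    return env_factory(num_envs, device=device)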
class GenericEnv(Env, Generic[_ObservationSpec, _ActionSpec]):
    """Generic version of :class:`Env` for environments with constant specs."""

    #: Environment observation spec.
    observation_spec: _ObservationSpec

    #: Environment action spec.
    action_spec: _ActionSpec
class DummyEnv(GenericEnv[Unbounded, _ActionSpec]):
    """The simplest environment possible.

    Useful for testing and debugging algorithms and policies. The state
    is just a position along a 1D axis, and the action perturbs the state
    by some amount. The reward is the negative of the state's distance
    from the origin, incentivizing policies to drive the state to the
    origin as quickly as possible.

    The environment's action space and step function are defined by
    subclasses.

    """

    #: State magnitude bounds for generating initial states upon
    #: environment creation and environment resets.
    bounds: float

    #: Current environment state: a position along a 1D axis.
    state: torch.Tensor

    def __init__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> None:
        super().__init__(num_envs, horizon, device=device)
        self.observation_spec = Unbounded(1, device=self.device)
        self.bounds = 100.0
    def reset(self, *, config: None | dict[str, Any] = None) -> torch.Tensor:
        config = config or {}
        # Optionally override the state bounds via the reset config.
        self.bounds = config.get("bounds", self.bounds)
        # Sample initial states uniformly within the bounds.
        self.state = torch.empty(self.num_envs, 1, device=self.device).uniform_(
            -self.bounds, self.bounds
        )
        return self.state
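# Usage note (illustrative, not part of the module): ``DummyEnv`` is abstract
# because ``step`` is deferred to subclasses, but any subclass inherits this
# ``reset``. Passing a config retunes the initial-state range between episodes
# without reinstantiating, e.g. (using a subclass defined below):
#
#     env = DiscreteDummyEnv(32)
#     obs = env.reset(config={"bounds": 10.0})  # initial states in [-10, 10]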
class ContinuousDummyEnv(DummyEnv[Unbounded]):
    """A continuous version of the dummy environment.

    Actions include moving the state left or right at any magnitude.

    """

    def __init__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> None:
        super().__init__(num_envs, horizon, device=device)
        self.action_spec = Unbounded(shape=torch.Size([1]), device=device)
    def step(self, action: torch.Tensor) -> TensorDict:
        self.state += action
        return TensorDict(
            {DataKeys.OBS: self.state, DataKeys.REWARDS: -self.state.abs()},
            batch_size=self.num_envs,
            device=self.device,
        )
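# Demonstration sketch (hypothetical helper, not part of the module): roll out
# random continuous actions. Rewards peak (at zero) for states at the origin.
def _demo_continuous_rollout(num_steps: int = 8) -> None:
    env = ContinuousDummyEnv(4)
    obs = env.reset()
    for _ in range(num_steps):
        # Random actions; a real policy would condition on ``obs``.
        action = torch.randn(env.num_envs, 1, device=env.device)
        out = env.step(action)
        obs = out[DataKeys.OBS]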
class DiscreteDummyEnv(DummyEnv[Categorical]):
    """A discrete version of the dummy environment.

    Actions include moving the state left or right one unit. This
    environment is considered more difficult to solve than its continuous
    counterpart because of the limited action space.

    """

    def __init__(
        self,
        num_envs: int,
        /,
        horizon: None | int = None,
        *,
        device: Device = "cpu",
    ) -> None:
        super().__init__(num_envs, horizon, device=device)
        self.action_spec = Categorical(2, shape=torch.Size([1]), device=device)
    def step(self, action: torch.Tensor) -> TensorDict:
        # Map the binary action {0, 1} to a unit step {-1, +1}.
        self.state += 2 * action - 1
        return TensorDict(
            {DataKeys.OBS: self.state, DataKeys.REWARDS: -self.state.abs()},
            batch_size=self.num_envs,
            device=self.device,
        )
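# Demonstration sketch (hypothetical helper, not part of the module): the
# discrete environment always moves the state by exactly one unit, which is
# why it is harder to settle at the origin than the continuous variant.
def _demo_discrete_rollout(num_steps: int = 8) -> None:
    env = DiscreteDummyEnv(4)
    env.reset(config={"bounds": 10.0})  # narrower initial-state range
    for _ in range(num_steps):
        # Random binary actions in {0, 1}; ``step`` maps them to {-1, +1}.
        action = torch.randint(2, (env.num_envs, 1), device=env.device)
        out = env.step(action)
        # ``out`` carries "obs" and "rewards" for all parallel environments.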