"""Environment protocol definition and helper dummy environment definitions."""
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, Protocol, TypeVar
import torch
from tensordict import TensorDict
from torchrl.data import Categorical, TensorSpec, Unbounded
from .data import DataKeys, Device
_ObservationSpec = TypeVar("_ObservationSpec", bound=TensorSpec)
_ActionSpec = TypeVar("_ActionSpec", bound=TensorSpec)


class Env(ABC):
"""Protocol defining the IsaacGym -like environments for supporting
highly parallelized simulation.
To define your own custom environment, you must define the following
instance attributes:
- :attr:`Env.action_spec`: The spec defining the environment's
inputs for its step function.
- :attr:`Env.observation_spec`: The spec defining part of the
environment's outputs for its reset and step functions.
You must also define the following methods:
- :meth:`Env.reset`: Returns the initial observation.
- :meth:`Env.step`: Takes an action and returns the updated
environment observation and the new environment reward.
Args:
num_envs: Number of parallel and independent environments being
simulated by one :class:`Env` instance.
horizon: Number of steps the environment expects to take before
being reset. ``None`` suggests the environment may never
reset.
device: Device the environment's underlying data should be
initialized on.
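
    A sketch of how a trainer might drive an :class:`Env` implementation
    (the ``collect`` helper and ``policy`` callable are illustrative
    assumptions, not part of this module):

    .. code-block:: python

        def collect(env: Env, policy, num_steps: int) -> list[TensorDict]:
            # Roll the environment forward under ``policy``, gathering the
            # per-step tensordicts of observations and rewards.
            out = []
            obs = env.reset()
            for _ in range(num_steps):
                action = policy(obs)  # Must match ``env.action_spec``.
                out.append(env.step(action))
                obs = out[-1]["obs"]
            return out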
"""
#: Spec defining the environment's inputs (and policy's action
#: distribution's outputs). Used for initializing the policy, the
#: policy's underlying components, and the learning buffer.
action_spec: TensorSpec
#: Device the environment's states, observations, and rewards reside
#: on.
device: Device
    #: The number of steps the environment expects to be taken before being
    #: reset. ``None`` suggests the environment may never be reset, though
    #: environments are not required to honor this convention.
    horizon: None | int
#: An optional attribute denoting the max number of steps an environment
#: may take before being reset. Used to validate environment instantiation.
max_horizon: ClassVar[int]
#: An optional attribute denoting the max number of parallel environments
#: an environment instance may hold at any given time. Used to validate
#: environment instantiation.
max_num_envs: ClassVar[int]
#: Number of parallel and independent environments being simulated.
num_envs: int
#: Spec defining part of the environment's outputs (and policy's
#: model's outputs). Used for initializing the policy, the
#: policy's underlying components, and the learning buffer.
observation_spec: TensorSpec

    def __init__(
self,
num_envs: int,
/,
horizon: None | int = None,
*,
device: Device = "cpu",
) -> None:
if hasattr(self, "max_horizon") and horizon is not None:
if not (horizon <= self.max_horizon):
raise ValueError(
f"{self.__class__.__name__} `horizon` must be <="
f" {self.max_horizon}."
)
if hasattr(self, "max_num_envs"):
if not (num_envs <= self.max_num_envs):
raise ValueError(
f"{self.__class__.__name__} `num_envs` must be <="
f" {self.max_num_envs}."
)
self.num_envs = num_envs
self.horizon = horizon
self.device = device

    @abstractmethod
def reset(
self, *, config: None | dict[str, Any] = None
) -> torch.Tensor | TensorDict:
"""Reset the environment, applying a new environment config to it and
returning a new, initial observation from the environment.
Args:
config: Environment configuration/options/parameters.
Returns:
Initial observation from the reset environment with spec
:attr:`Env.observation_spec`.
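
        A sketch of passing an environment-specific config (the ``"bounds"``
        key shown here is recognized by :class:`DummyEnv`, not by all
        environments):

        .. code-block:: python

            obs = env.reset(config={"bounds": 10.0})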
"""

    @abstractmethod
def step(self, action: torch.Tensor | TensorDict) -> TensorDict:
"""Step the environment by applying an action, simulating an environment
transition, and returning an observation and a reward.
Args:
action: Action to apply to the environment with tensor spec
:attr:`Env.action_spec`.
Returns:
A tensordict containing "obs" and "rewards" keys and values.
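
        A sketch of consuming the result (the key names follow ``DataKeys``):

        .. code-block:: python

            td = env.step(action)
            obs, rewards = td["obs"], td["rewards"]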
"""


class EnvFactory(Protocol):
"""Factory protocol describing how to create an environment instance."""
#: An optional attribute denoting the max number of steps an environment
#: may take before being reset. Used to validate environment instantiation.
max_horizon: ClassVar[int]
#: An optional attribute denoting the max number of parallel environments
#: an environment instance may hold at any given time. Used to validate
#: environment instantiation.
max_num_envs: ClassVar[int]
def __call__(
self,
num_envs: int,
/,
horizon: None | int = None,
*,
device: Device = "cpu",
) -> Env:
...


class GenericEnv(Env, Generic[_ObservationSpec, _ActionSpec]):
"""Generic version of `Env` for environments with constant specs."""
#: Environment observation spec.
observation_spec: _ObservationSpec
    #: Environment action spec.
action_spec: _ActionSpec


class DummyEnv(GenericEnv[Unbounded, _ActionSpec]):
"""The simplest environment possible.
Useful for testing and debugging algorithms and policies. The state
is just a position along a 1D axis and the action perturbs the
state by some amount. The reward is the negative of the state's distance
from the origin, incentivizing policies to drive the state to the
origin as quickly as possible.
The environment's action space and step functions are defined by
subclasses.
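
    A short usage sketch (the values are illustrative):

    .. code-block:: python

        env = ContinuousDummyEnv(4)
        obs = env.reset(config={"bounds": 10.0})  # Shape ``(4, 1)``.
        td = env.step(torch.zeros(4, 1))
        td["rewards"]  # Negative absolute distance from the origin.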
"""
#: State magnitude bounds for generating initial states upon
#: environment creation and environment resets.
bounds: float
#: Current environment state that's a position along a 1D axis.
state: torch.Tensor

    def __init__(
self,
num_envs: int,
/,
horizon: None | int = None,
*,
device: Device = "cpu",
) -> None:
super().__init__(num_envs, horizon, device=device)
self.observation_spec = Unbounded(1, device=self.device)
self.bounds = 100.0

    def reset(self, *, config: None | dict[str, Any] = None) -> torch.Tensor:
config = config or {}
self.bounds = config.get("bounds", self.bounds)
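        # Redraw each environment's initial position uniformly from
        # ``[-bounds, bounds]``.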
self.state = torch.empty(self.num_envs, 1, device=self.device).uniform_(
-self.bounds, self.bounds
)
return self.state


class ContinuousDummyEnv(DummyEnv[Unbounded]):
"""A continuous version of the dummy environment.
Actions include moving the state left or right at any magnitude.
"""

    def __init__(
self,
num_envs: int,
/,
horizon: None | int = None,
*,
device: Device = "cpu",
) -> None:
super().__init__(num_envs, horizon, device=device)
self.action_spec = Unbounded(shape=torch.Size([1]), device=device)

    def step(self, action: torch.Tensor) -> TensorDict:
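        # Shift each environment's 1D position by the (signed) action.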
self.state += action
return TensorDict(
{DataKeys.OBS: self.state, DataKeys.REWARDS: -self.state.abs()},
batch_size=self.num_envs,
device=self.device,
)


class DiscreteDummyEnv(DummyEnv[Categorical]):
"""A discrete version of the dummy environment.
Actions include moving the state left or right one unit. This
environment is considered more difficult to solve than its
continuous counterpart because of the limited action space.
"""

    def __init__(
self,
num_envs: int,
/,
horizon: None | int = None,
*,
device: Device = "cpu",
) -> None:
super().__init__(num_envs, horizon, device=device)
self.action_spec = Categorical(2, shape=torch.Size([1]), device=device)

    def step(self, action: torch.Tensor) -> TensorDict:
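        # Map the discrete action in ``{0, 1}`` to a unit step in ``{-1, +1}``.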
self.state += 2 * action - 1
return TensorDict(
{DataKeys.OBS: self.state, DataKeys.REWARDS: -self.state.abs()},
batch_size=self.num_envs,
device=self.device,
)