Source code for jaxatari.environment

from enum import Enum
from typing import Tuple, Generic, TypeVar
import jax.numpy as jnp
import jax.random as jrandom
from jaxatari.spaces import Space


# Type parameters for the generic JaxEnvironment base class defined below.
# Each concrete environment binds these to its own types.
EnvObs = TypeVar("EnvObs")  # type of the observation handed to the agent
EnvState = TypeVar("EnvState")  # type of the environment state
EnvInfo = TypeVar("EnvInfo")  # type of the additional information returned by step
EnvConstants = TypeVar("EnvConstants")  # type of the environment constants

class JAXAtariAction:
    """
    Plain-class namespace holding the Atari action ids as Python integers.

    The values are ordinary ints, so they can be placed directly into
    JAX arrays.
    """

    # Neutral and single-input actions.
    NOOP: int = 0
    FIRE: int = 1
    UP: int = 2
    RIGHT: int = 3
    LEFT: int = 4
    DOWN: int = 5

    # Diagonal movement.
    UPRIGHT: int = 6
    UPLEFT: int = 7
    DOWNRIGHT: int = 8
    DOWNLEFT: int = 9

    # Movement combined with fire.
    UPFIRE: int = 10
    RIGHTFIRE: int = 11
    LEFTFIRE: int = 12
    DOWNFIRE: int = 13
    UPRIGHTFIRE: int = 14
    UPLEFTFIRE: int = 15
    DOWNRIGHTFIRE: int = 16
    DOWNLEFTFIRE: int = 17

    @classmethod
    def get_all_values(cls) -> jnp.ndarray:
        """
        Collect every action id into a single array.

        Returns:
            A 1-D ``int32`` array holding the 18 action ids in declaration
            order (NOOP first, DOWNLEFTFIRE last).
        """
        # The action set is fixed, so the members are spelled out one by one
        # instead of being discovered by introspection.
        ordered_actions = (
            cls.NOOP,
            cls.FIRE,
            cls.UP,
            cls.RIGHT,
            cls.LEFT,
            cls.DOWN,
            cls.UPRIGHT,
            cls.UPLEFT,
            cls.DOWNRIGHT,
            cls.DOWNLEFT,
            cls.UPFIRE,
            cls.RIGHTFIRE,
            cls.LEFTFIRE,
            cls.DOWNFIRE,
            cls.UPRIGHTFIRE,
            cls.UPLEFTFIRE,
            cls.DOWNRIGHTFIRE,
            cls.DOWNLEFTFIRE,
        )
        return jnp.array(ordered_actions, dtype=jnp.int32)
class JaxEnvironment(Generic[EnvState, EnvObs, EnvInfo, EnvConstants]):
    """
    Abstract class for a JAX environment.

    Every method except ``__init__`` raises ``NotImplementedError`` and must
    be overridden by a concrete environment. The class is not declared
    through ``abc``, so instantiation itself succeeds; the failure only
    happens when an un-overridden method is called.

    Generics:
        EnvState: The type of the environment state.
        EnvObs: The type of the observation.
        EnvInfo: The type of the additional information.
        EnvConstants: The type of the environment constants.
    """

    def __init__(self, consts: EnvConstants = None):
        """
        Stores the environment constants on the instance.

        Args:
            consts: The environment constants; ``None`` when not supplied.
        """
        self.consts = consts

    def reset(self, key: jrandom.PRNGKey = None) -> Tuple[EnvObs, EnvState]:
        """
        Resets the environment to the initial state.

        Args:
            key: Optional JAX PRNG key; defaults to ``None``. How it is used
                is up to the concrete environment.

        Returns:
            The initial observation and the initial environment state.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def step(
        self, state: EnvState, action
    ) -> Tuple[EnvObs, EnvState, float, bool, EnvInfo]:
        """
        Takes a step in the environment.

        Args:
            state: The current environment state.
            action: The action to take.

        Returns:
            The observation, the new environment state, the reward, whether
            the state is terminal, and additional info.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def render(self, state: EnvState) -> jnp.ndarray:
        """
        Renders the environment state to a single image.

        Args:
            state: The environment state.

        Returns:
            A single image of the environment state.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def action_space(self) -> Space:
        """
        Returns the action space of the environment, i.e. the ``Space``
        describing the actions that can be taken.

        Returns:
            The action space of the environment.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def observation_space(self) -> Space:
        """
        Returns the observation space of the environment.

        Returns:
            The observation space of the environment.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def image_space(self) -> Space:
        """
        Returns the image space of the environment.

        Returns:
            The image space of the environment.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def _get_observation(self, state: EnvState) -> EnvObs:
        """
        Converts the environment state to the observation by filtering out
        non-relevant information.

        Args:
            state: The environment state.

        Returns:
            observation

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def obs_to_flat_array(self, obs: EnvObs) -> jnp.ndarray:
        """
        Converts the observation to a flat array.

        Args:
            obs: The observation.

        Returns:
            The observation as a flat array.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    # ``jnp.ndarray`` is the array type; ``jnp.array`` is the constructor
    # function and is not meaningful as an annotation.
    def _get_info(self, state: EnvState, all_rewards: jnp.ndarray = None) -> EnvInfo:
        """
        Extracts information from the environment state that is not relevant
        for the agent.

        Args:
            state: The environment state.
            all_rewards: Optional array of rewards; defaults to ``None``.
                NOTE(review): its contents are not defined in this base
                class - presumably the individual reward values computed for
                the step; confirm against the concrete environments.

        Returns:
            info

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def _get_reward(self, previous_state: EnvState, state: EnvState) -> float:
        """
        Calculates the reward from the environment state.

        Args:
            previous_state: The previous environment state.
            state: The environment state.

        Returns:
            reward

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")

    def _get_done(self, state: EnvState) -> bool:
        """
        Determines if the environment state is a terminal state.

        Args:
            state: The environment state.

        Returns:
            True if the state is terminal, False otherwise.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError("Abstract method")