Base Environment
- class jaxatari.environment.JAXAtariAction
Bases: object
“Namespace” for Atari action integer constants. These are directly usable in JAX arrays.
- DOWN: int = 5
- DOWNFIRE: int = 13
- DOWNLEFT: int = 9
- DOWNLEFTFIRE: int = 17
- DOWNRIGHT: int = 8
- DOWNRIGHTFIRE: int = 16
- FIRE: int = 1
- LEFT: int = 4
- LEFTFIRE: int = 12
- NOOP: int = 0
- RIGHT: int = 3
- RIGHTFIRE: int = 11
- UP: int = 2
- UPFIRE: int = 10
- UPLEFT: int = 7
- UPLEFTFIRE: int = 15
- UPRIGHT: int = 6
- UPRIGHTFIRE: int = 14
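Because the constants are plain Python ints, they can be placed directly into `jnp` arrays and compared inside jit-compiled code. The sketch below mirrors the values listed above in a local stand-in class (`Action`, hypothetical, so the snippet runs without jaxatari installed; with the library you would use `JAXAtariAction` itself). It builds a hypothetical reduced action set, samples from it with `jax.random.choice`, and tests which actions include FIRE, relying on the pattern visible in the table: FIRE is 1 and every combined *FIRE action is 10 through 17.

```python
import jax
import jax.numpy as jnp


class Action:
    """Local stand-in mirroring the JAXAtariAction values listed above."""
    NOOP = 0
    FIRE = 1
    UP = 2
    RIGHT = 3
    LEFT = 4
    DOWN = 5
    UPRIGHT = 6
    UPLEFT = 7
    DOWNRIGHT = 8
    DOWNLEFT = 9
    UPFIRE = 10
    RIGHTFIRE = 11
    LEFTFIRE = 12
    DOWNFIRE = 13
    UPRIGHTFIRE = 14
    UPLEFTFIRE = 15
    DOWNRIGHTFIRE = 16
    DOWNLEFTFIRE = 17


# Hypothetical reduced action set for a game that only moves left/right and fires.
MINIMAL_ACTIONS = jnp.array(
    [Action.NOOP, Action.FIRE, Action.RIGHT, Action.LEFT,
     Action.RIGHTFIRE, Action.LEFTFIRE],
    dtype=jnp.int32,
)


@jax.jit
def has_fire(action: jax.Array) -> jax.Array:
    # FIRE is 1; all combined *FIRE actions occupy 10..17 in the table above.
    return (action == Action.FIRE) | (action >= Action.UPFIRE)


def sample_actions(key: jax.Array, n: int) -> jax.Array:
    # Draw n actions uniformly (with replacement) from the reduced action set.
    return jax.random.choice(key, MINIMAL_ACTIONS, shape=(n,))


actions = sample_actions(jax.random.PRNGKey(0), 8)
print(actions, has_fire(actions))
```

The range check in `has_fire` only holds for the 18 values listed here; a game using a different numbering would need its own predicate.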
- class jaxatari.environment.JaxEnvironment(consts: EnvConstants | None = None)
Bases: Generic[EnvState, EnvObs, EnvInfo, EnvConstants]
Abstract class for a JAX environment. Generic parameters: EnvState is the type of the environment state, EnvObs the type of an observation, EnvInfo the type of the additional information, and EnvConstants the type of the environment constants.
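One way a concrete game could fill in the four type parameters is with NamedTuples, which JAX treats as pytrees. So that the sketch runs without jaxatari installed, it defines a hypothetical stand-in base class (`EnvBase`) that exposes only the `reset`/`step` signatures documented on this page; the toy game (`CounterEnv`, `CounterState`, `CounterObs`, `CounterInfo`, `CounterConstants`) is invented for illustration and is not part of the library.

```python
from typing import Generic, NamedTuple, Optional, Tuple, TypeVar

import jax
import jax.numpy as jnp

EnvState = TypeVar("EnvState")
EnvObs = TypeVar("EnvObs")
EnvInfo = TypeVar("EnvInfo")
EnvConstants = TypeVar("EnvConstants")


class EnvBase(Generic[EnvState, EnvObs, EnvInfo, EnvConstants]):
    """Hypothetical stand-in for jaxatari.environment.JaxEnvironment."""

    def __init__(self, consts: Optional[EnvConstants] = None):
        self.consts = consts

    def reset(self, key=None) -> Tuple[EnvObs, EnvState]:
        raise NotImplementedError

    def step(self, state: EnvState, action) -> Tuple[EnvObs, EnvState, float, bool, EnvInfo]:
        raise NotImplementedError


class CounterConstants(NamedTuple):
    max_steps: int = 10


class CounterState(NamedTuple):
    position: jax.Array
    time: jax.Array


class CounterObs(NamedTuple):
    position: jax.Array


class CounterInfo(NamedTuple):
    time: jax.Array


class CounterEnv(EnvBase[CounterState, CounterObs, CounterInfo, CounterConstants]):
    """Toy game: action 3 (RIGHT) moves +1, action 4 (LEFT) moves -1."""

    def __init__(self, consts: Optional[CounterConstants] = None):
        super().__init__(consts if consts is not None else CounterConstants())

    def reset(self, key=None):
        # key is unused because this toy reset is deterministic.
        state = CounterState(position=jnp.array(0), time=jnp.array(0))
        return CounterObs(state.position), state

    def step(self, state, action):
        delta = jnp.where(action == 3, 1, jnp.where(action == 4, -1, 0))
        new_state = CounterState(state.position + delta, state.time + 1)
        reward = delta.astype(jnp.float32)
        done = new_state.time >= self.consts.max_steps
        return CounterObs(new_state.position), new_state, reward, done, CounterInfo(new_state.time)


env = CounterEnv()
obs, state = env.reset()
obs, state, reward, done, info = jax.jit(env.step)(state, 3)  # 3 == RIGHT
```

Keeping the state a pytree of arrays is what lets `step` be wrapped in `jax.jit`; inside JAX the reward and done flag come back as 0-d arrays rather than Python `float`/`bool`.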
- action_space() → Space
Returns the action space of the environment, describing the actions that can be taken.
Returns: The action space of the environment.
- image_space() → Space
Returns the image space of the environment.
Returns: The image space of the environment.
- observation_space() → Space
Returns the observation space of the environment.
Returns: The observation space of the environment.
- render(state: EnvState) → Tuple[Array]
Renders the environment state to a single image.
Parameters: state (EnvState): The environment state.
Returns: A single image of the environment state.
- reset(key: PRNGKey | None = None) → Tuple[EnvObs, EnvState]
Resets the environment to its initial state.
Parameters: key (PRNGKey | None): An optional JAX PRNG key.
Returns: The initial observation and the initial environment state.
- step(state: EnvState, action) → Tuple[EnvObs, EnvState, float, bool, EnvInfo]
Takes one step in the environment.
Parameters:
state (EnvState): The current environment state.
action: The action to take.
Returns: The new observation, the new environment state, the reward, whether the new state is terminal, and additional info.
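Since `step` returns the new state instead of mutating the environment, an episode can in principle be rolled out inside `jax.lax.scan`, and `reset` can be vectorised over a batch of PRNG keys with `jax.vmap`. This sketch assumes the environment's `reset` and `step` are traceable by JAX, which this page does not state. To stay self-contained it uses hypothetical toy functions (`toy_reset`, `toy_step`, `rollout`) with the documented return shapes rather than a real jaxatari game.

```python
import jax
import jax.numpy as jnp


def toy_reset(key):
    # Hypothetical reset: the initial position is random in [0, 10).
    position = jax.random.randint(key, (), 0, 10)
    state = {"position": position, "time": jnp.array(0)}
    return position, state  # (obs, state)


def toy_step(state, action):
    # Hypothetical step: action 3 (RIGHT) moves +1, action 4 (LEFT) moves -1.
    delta = jnp.where(action == 3, 1, jnp.where(action == 4, -1, 0))
    new_state = {"position": state["position"] + delta, "time": state["time"] + 1}
    reward = delta.astype(jnp.float32)
    done = new_state["time"] >= 5
    info = {"time": new_state["time"]}
    return new_state["position"], new_state, reward, done, info


def rollout(state, actions):
    """Apply a fixed action sequence with lax.scan, collecting rewards and done flags."""
    def body(carry, action):
        obs, new_state, reward, done, info = toy_step(carry, action)
        return new_state, (reward, done)

    final_state, (rewards, dones) = jax.lax.scan(body, state, actions)
    return final_state, rewards, dones


# Batched reset: one independent PRNG key per environment.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
batch_obs, batch_state = jax.vmap(toy_reset)(keys)

# Roll out the first environment of the batch for five steps.
first_state = jax.tree_util.tree_map(lambda x: x[0], batch_state)
actions = jnp.array([3, 3, 4, 0, 3])
final_state, rewards, dones = jax.jit(rollout)(first_state, actions)
```

The scan keeps stepping after `done` becomes true; a training loop would usually mask rewards or reset at that point, which this sketch omits.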
- class jaxatari.environment.ObjectObservation(x: jax.Array, y: jax.Array, width: jax.Array, height: jax.Array, active: jax.Array = <factory>, visual_id: jax.Array = <factory>, state: jax.Array = <factory>, orientation: jax.Array = <factory>)
Bases: object
Dataclass for object-centric observations of objects in jaxatari environments. A single instance can hold 1 to N objects of the same type (for example, 12 sharks in Seaquest or 1 player ship in Asteroids). Always instantiate it via the create() classmethod so that defaults are handled properly.
Attributes:
- x (jax.Array): x position of the object(s).
- y (jax.Array): y position of the object(s).
- width (jax.Array): width of the object(s).
- height (jax.Array): height of the object(s).
- active (jax.Array): whether the object is currently active.
- active: Array
- classmethod create(x, y, width, height, active=None, visual_id=None, state=None, orientation=None)
- height: Array
- orientation: Array
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- state: Array
- visual_id: Array
- width: Array
- x: Array
- y: Array
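Because one instance can hold several objects of the same type, each field can be read as one entry per object, with `active` acting as a mask. The sketch below uses a hypothetical frozen-dataclass stand-in (`ObjObs`) so it runs without jaxatari. Its `create()` defaults (every object active, `visual_id`/`state`/`orientation` zero) are assumptions for illustration, not the library's actual defaults, and its `replace` is built on `dataclasses.replace`. The helper `masked_positions` is likewise hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ObjObs:
    """Hypothetical stand-in for ObjectObservation (not the library class)."""
    x: jax.Array
    y: jax.Array
    width: jax.Array
    height: jax.Array
    active: jax.Array
    visual_id: jax.Array
    state: jax.Array
    orientation: jax.Array

    @classmethod
    def create(cls, x, y, width, height, active=None, visual_id=None, state=None, orientation=None):
        x = jnp.asarray(x)
        # Assumed defaults: every object active, other optional fields zero.
        zeros = jnp.zeros_like(x)
        return cls(
            x=x,
            y=jnp.asarray(y),
            width=jnp.broadcast_to(jnp.asarray(width), x.shape),
            height=jnp.broadcast_to(jnp.asarray(height), x.shape),
            active=jnp.ones(x.shape, dtype=bool) if active is None else jnp.asarray(active),
            visual_id=zeros if visual_id is None else jnp.asarray(visual_id),
            state=zeros if state is None else jnp.asarray(state),
            orientation=zeros if orientation is None else jnp.asarray(orientation),
        )

    def replace(self, **updates):
        # Return a new instance with some fields swapped out; self is left untouched.
        return dataclasses.replace(self, **updates)


def masked_positions(obs: ObjObs) -> jax.Array:
    """Return (N, 2) (x, y) positions, with inactive objects set to -1."""
    positions = jnp.stack([obs.x, obs.y], axis=-1)
    return jnp.where(obs.active[:, None], positions, -1)


# Three objects of the same type, each 8x4 pixels.
enemies = ObjObs.create(x=jnp.array([10, 50, 90]), y=jnp.array([20, 20, 40]), width=8, height=4)
# Deactivate the second object without mutating `enemies`.
hit = enemies.replace(active=jnp.array([True, False, True]))
print(masked_positions(hit))
```

Routing construction through `create()` keeps every optional field shaped like `x`, which is the kind of default handling the class description asks for; the exact defaults in jaxatari may differ.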