Base Environment

class jaxatari.environment.JAXAtariAction

Bases: object

A “namespace” class for the integer constants of the 18 Atari actions (values 0 to 17). The constants are plain ints and can be used directly in JAX arrays.

DOWN: int = 5
DOWNFIRE: int = 13
DOWNLEFT: int = 9
DOWNLEFTFIRE: int = 17
DOWNRIGHT: int = 8
DOWNRIGHTFIRE: int = 16
FIRE: int = 1
LEFT: int = 4
LEFTFIRE: int = 12
NOOP: int = 0
RIGHT: int = 3
RIGHTFIRE: int = 11
UP: int = 2
UPFIRE: int = 10
UPLEFT: int = 7
UPLEFTFIRE: int = 15
UPRIGHT: int = 6
UPRIGHTFIRE: int = 14
classmethod get_all_values() → Array

Returns the values of all action constants as an Array.
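The sketch below shows one way the constants might be used. It copies the values from the table above into plain module-level names so that it runs without jaxatari installed. It then maps an index in a hypothetical reduced action set to a full action id. `MINIMAL_ACTIONS`, `FIRE_ACTIONS`, `to_env_action` and `is_fire` are illustrative names and are not part of jaxatari.

```python
# Local copy of the JAXAtariAction values listed above (no jaxatari import).
NOOP, FIRE, UP, RIGHT, LEFT, DOWN = 0, 1, 2, 3, 4, 5
UPRIGHT, UPLEFT, DOWNRIGHT, DOWNLEFT = 6, 7, 8, 9
UPFIRE, RIGHTFIRE, LEFTFIRE, DOWNFIRE = 10, 11, 12, 13
UPRIGHTFIRE, UPLEFTFIRE, DOWNRIGHTFIRE, DOWNLEFTFIRE = 14, 15, 16, 17

# Hypothetical reduced action set for a game that only needs horizontal
# movement and shooting; an agent picks an index into this tuple.
MINIMAL_ACTIONS = (NOOP, FIRE, RIGHT, LEFT, RIGHTFIRE, LEFTFIRE)

# Every action whose name includes FIRE.
FIRE_ACTIONS = frozenset({
    FIRE, UPFIRE, RIGHTFIRE, LEFTFIRE, DOWNFIRE,
    UPRIGHTFIRE, UPLEFTFIRE, DOWNRIGHTFIRE, DOWNLEFTFIRE,
})


def to_env_action(agent_index: int) -> int:
    """Map an index in the reduced action set to a full Atari action id."""
    return MINIMAL_ACTIONS[agent_index]


def is_fire(action: int) -> bool:
    """Return True if the action presses the fire button."""
    return action in FIRE_ACTIONS
```

Because the constants are ints, the same lookup also works on a traced index inside JAX code, for example `jnp.array(MINIMAL_ACTIONS)[idx]` after `import jax.numpy as jnp`.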
class jaxatari.environment.JaxEnvironment(consts: EnvConstants | None = None)

Bases: Generic[EnvState, EnvObs, EnvInfo, EnvConstants]

Abstract base class for a JAX environment.

Generics:

EnvState: The type of the environment state.
EnvObs: The type of the observation.
EnvInfo: The type of the additional information.
EnvConstants: The type of the environment constants.
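The four type parameters let a subclass declare which concrete types it works with. The sketch below uses `EnvironmentSketch`, a simplified and hypothetical stand-in that declares only `reset` and `step`. It is not the real JaxEnvironment. The sketch also defines a toy `CounterEnv`. The dataclasses, and the fallback to default constants when `consts` is None, are assumptions of this example.

```python
from dataclasses import dataclass
from typing import Generic, Optional, Tuple, TypeVar

# Type variables mirroring the four generics of JaxEnvironment.
EnvState = TypeVar("EnvState")
EnvObs = TypeVar("EnvObs")
EnvInfo = TypeVar("EnvInfo")
EnvConstants = TypeVar("EnvConstants")


class EnvironmentSketch(Generic[EnvState, EnvObs, EnvInfo, EnvConstants]):
    """Hypothetical, simplified stand-in for JaxEnvironment (reset/step only)."""

    def __init__(self, consts: Optional[EnvConstants] = None):
        self.consts = consts

    def reset(self, key=None) -> Tuple[EnvObs, EnvState]:
        raise NotImplementedError

    def step(self, state: EnvState, action) -> Tuple[EnvObs, EnvState, float, bool, EnvInfo]:
        raise NotImplementedError


@dataclass(frozen=True)
class CounterConstants:
    max_steps: int = 3


@dataclass(frozen=True)
class CounterState:
    t: int


@dataclass(frozen=True)
class CounterInfo:
    t: int


# The subscript binds EnvState=CounterState, EnvObs=int,
# EnvInfo=CounterInfo and EnvConstants=CounterConstants.
class CounterEnv(EnvironmentSketch[CounterState, int, CounterInfo, CounterConstants]):
    def __init__(self, consts: Optional[CounterConstants] = None):
        # Assumption of this sketch: fall back to default constants.
        super().__init__(consts if consts is not None else CounterConstants())

    def reset(self, key=None) -> Tuple[int, CounterState]:
        state = CounterState(t=0)
        return state.t, state

    def step(self, state: CounterState, action) -> Tuple[int, CounterState, float, bool, CounterInfo]:
        new_state = CounterState(t=state.t + 1)
        reward = 1.0 if action == 1 else 0.0  # 1 is JAXAtariAction.FIRE
        done = new_state.t >= self.consts.max_steps
        return new_state.t, new_state, reward, done, CounterInfo(t=new_state.t)
```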

action_space() → Space

Returns the action space of the environment, that is, the set of actions that can be taken.

Returns: The action space of the environment as a Space.

image_space() → Space

Returns the image space of the environment.

Returns: The image space of the environment.

observation_space() → Space

Returns the observation space of the environment.

Returns: The observation space of the environment.

render(state: EnvState) → Tuple[Array]

Renders the environment state to a single image.

Parameters:
state: The environment state.

Returns: A single image of the environment state.

reset(key: PRNGKey | None = None) → Tuple[EnvObs, EnvState]

Resets the environment to its initial state.

Parameters:
key: An optional JAX PRNG key.

Returns: The initial observation and the initial environment state.

step(state: EnvState, action) → Tuple[EnvObs, EnvState, float, bool, EnvInfo]

Takes a step in the environment.

Parameters:
state: The current environment state.
action: The action to take.

Returns: The observation, the new environment state, the reward, whether the new state is terminal, and additional info.
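`step` takes the current state as an argument and returns a new one, so the caller must pass the state through the episode explicitly. The following is a minimal sketch of such a loop. It assumes a hypothetical `ToyEnv` that follows the reset/step signatures above and uses plain Python values in place of JAX arrays. `ToyEnv`, `ToyState` and `rollout` are illustrative names.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ToyState:
    # Hypothetical state: position on a 1-D track and elapsed steps.
    pos: int
    t: int


class ToyEnv:
    """Hypothetical environment following the reset/step contract above."""

    def reset(self, key=None) -> Tuple[int, ToyState]:
        state = ToyState(pos=0, t=0)
        return state.pos, state

    def step(self, state: ToyState, action: int):
        # Action ids borrowed from JAXAtariAction: RIGHT = 3, LEFT = 4.
        dx = {3: 1, 4: -1}.get(action, 0)
        new_state = ToyState(pos=state.pos + dx, t=state.t + 1)
        reward = float(dx)
        done = new_state.t >= 5
        info = {"t": new_state.t}
        # Same order as step(): obs, new state, reward, terminal flag, info.
        return new_state.pos, new_state, reward, done, info


def rollout(env, policy: Callable[[int], int], key=None) -> Tuple[float, List[int]]:
    """Run one episode, feeding the state returned by step into the next call."""
    obs, state = env.reset(key)
    total, observations = 0.0, [obs]
    done = False
    while not done:
        action = policy(obs)
        obs, state, reward, done, _info = env.step(state, action)
        total += reward
        observations.append(obs)
    return total, observations
```

Under `jax.jit`, the terminal flag becomes a traced value. The Python `while` loop would then typically be replaced by `jax.lax.scan` or `jax.lax.while_loop`.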

class jaxatari.environment.ObjectObservation(x: jax.Array, y: jax.Array, width: jax.Array, height: jax.Array, active: jax.Array = <factory>, visual_id: jax.Array = <factory>, state: jax.Array = <factory>, orientation: jax.Array = <factory>)

Bases: object

Dataclass for object-centric observations of objects in jaxatari environments. A single instance can hold 1 to N objects of the same type (for example, 12 sharks in Seaquest or 1 player ship in Asteroids). Always instantiate it via the create() classmethod so that default values are handled properly.

Attributes:

x (jax.Array): x position of each object.
y (jax.Array): y position of each object.
width (jax.Array): width of each object.
height (jax.Array): height of each object.
active (jax.Array): whether each object is currently active.

active: Array
classmethod create(x, y, width, height, active=None, visual_id=None, state=None, orientation=None)

Creates an ObjectObservation. Optional fields passed as None are filled with default values.
height: Array
orientation: Array
replace(**updates)

Returns a new object in which the specified fields are replaced with the given values.

state: Array
visual_id: Array
width: Array
x: Array
y: Array
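The create-then-replace pattern can be sketched with a standard-library dataclass. `ObjectObservationSketch` is a hypothetical stand-in that uses Python lists instead of jax.Array, with one list entry per object. The defaults it fills in (every object active, all other optional fields 0) are assumptions. The actual jaxatari defaults are not documented here.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import List, Optional, Sequence


def _filled(values: Optional[Sequence[int]], default: int, n: int) -> List[int]:
    # Use the given per-object values, or repeat the default once per object.
    return list(values) if values is not None else [default] * n


@dataclass(frozen=True)
class ObjectObservationSketch:
    """Hypothetical stand-in for ObjectObservation: one list entry per object."""

    x: List[int]
    y: List[int]
    width: List[int]
    height: List[int]
    active: List[int] = field(default_factory=list)
    visual_id: List[int] = field(default_factory=list)
    state: List[int] = field(default_factory=list)
    orientation: List[int] = field(default_factory=list)

    @classmethod
    def create(cls, x, y, width, height, active=None, visual_id=None,
               state=None, orientation=None):
        n = len(x)
        # Assumed defaults: all objects active, other optional fields 0.
        return cls(
            x=list(x), y=list(y), width=list(width), height=list(height),
            active=_filled(active, 1, n),
            visual_id=_filled(visual_id, 0, n),
            state=_filled(state, 0, n),
            orientation=_filled(orientation, 0, n),
        )

    def replace(self, **updates):
        # Build a new frozen instance; the original is left unchanged.
        return dataclasses.replace(self, **updates)


# Three sharks (as in the Seaquest example) held in one observation.
sharks = ObjectObservationSketch.create(
    x=[10, 40, 70], y=[50, 50, 80], width=[8, 8, 8], height=[7, 7, 7]
)
moved = sharks.replace(x=[12, 42, 72])
```

The sketch is frozen to mirror how JAX code usually treats state. Fields are never mutated in place, and `replace` produces an updated copy.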