Wrappers

Jaxatari Wrappers

class jaxatari.wrappers.AtariState(env_state: Any, key: jax.Array, step: int, prev_action: int)

Bases: object

env_state: Any
key: Array
prev_action: int
replace(**updates)

Returns a new object replacing the specified fields with new values.
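The state classes on this page all expose replace, which returns a new object rather than mutating the existing one. A minimal, dependency-free sketch of that pattern follows. DemoState is a hypothetical stand-in built on the standard library's frozen dataclasses, not the actual JAXAtari implementation.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class DemoState:
    """Hypothetical stand-in mirroring the fields of AtariState."""
    env_state: Any
    key: Any
    step: int
    prev_action: int

    def replace(self, **updates):
        # Build a new instance; the original object is left untouched.
        return dataclasses.replace(self, **updates)


old = DemoState(env_state=None, key=0, step=0, prev_action=0)
new = old.replace(step=old.step + 1, prev_action=3)
```

Because the state is never modified in place, it can be threaded through pure functions such as reset and step.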

step: int

class jaxatari.wrappers.AtariWrapper(env, sticky_actions: float = 0.25, episodic_life: bool = True, first_fire: bool = True, noop_max: int = 30, full_action_space: bool = False, max_frames_per_episode: int = 108000)

Bases: JaxatariWrapper

Wrapper for Atari environments that applies Atari-specific control logic. Returns single-step, single-frame observations from the wrapped base env.

Parameters:
env: The environment to wrap.
sticky_actions: Sticky action probability in [0, 1]. Defaults to 0.25.
episodic_life: Treat the loss of a life as termination (terminated=True). Does not reset the environment. Defaults to True.
first_fire: Take the FIRE action on reset. Defaults to True.
noop_max: Maximum number of no-op actions to take on reset. Defaults to 30.
full_action_space: Use the full action space of 18 actions. Defaults to False (minimal action set).
max_frames_per_episode: Maximum number of frames per episode before truncation. Defaults to 108,000 (30 minutes at 60 fps).

Note: Typically, this wrapper is followed by PixelObsWrapper, ObjectCentricWrapper or PixelAndObjectCentricWrapper. Frame-skipping, max-pooling, frame-stacking and reward clipping are handled in those.
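To illustrate what sticky_actions means, here is a small sketch of the sticky-action rule: with the given probability, the previous action is repeated instead of the newly requested one. The helper sticky_action is hypothetical and uses Python's random module for brevity. The wrapper itself presumably draws from the JAX key stored in AtariState, which this reference does not show.

```python
import random


def sticky_action(rng, prev_action, new_action, sticky_prob=0.25):
    """Return prev_action with probability sticky_prob, else new_action."""
    # rng.random() is uniform on [0.0, 1.0).
    if rng.random() < sticky_prob:
        return prev_action
    return new_action


rng = random.Random(0)
# With sticky_prob=0.0 the requested action always goes through;
# with sticky_prob=1.0 the previous action is always repeated.
always_new = sticky_action(rng, prev_action=1, new_action=2, sticky_prob=0.0)
always_old = sticky_action(rng, prev_action=1, new_action=2, sticky_prob=1.0)

# The default frame cap corresponds to 30 minutes of play at 60 fps.
minutes = 108_000 / 60 / 60
```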

image_space() → Box

Returns the image space.

observation_space() → Space

Returns the single-frame base observation space.

reset(key: Array) → Tuple[Array | ndarray | bool_ | number, AtariState]
step(state: AtariState, action: int) → Tuple[Array | ndarray | bool_ | number, AtariState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.FlattenObservationWrapper(env)

Bases: JaxatariWrapper

A wrapper that flattens each leaf array in an observation Pytree.

Compatible with all the other wrappers, flattens the observations whilst preserving the overarching structure (i.e. if the observation is a tuple of multiple observations, the flattened observation will be a tuple of flattened observations).

observation_space() → Space

Returns a space where each leaf array is flattened.
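The structure-preserving flattening described above can be sketched without any dependencies. In this sketch, leaf arrays are modelled as nested lists and containers as tuples or dicts; both function names are hypothetical. With real JAX arrays, one would typically map a reshape over the leaves with jax.tree_util.tree_map instead.

```python
def flatten_leaf(leaf):
    """Flatten an arbitrarily nested list of numbers into a flat list."""
    if not isinstance(leaf, list):
        return [leaf]
    flat = []
    for item in leaf:
        flat.extend(flatten_leaf(item))
    return flat


def flatten_observation(obs):
    """Flatten every leaf while keeping tuple/dict containers intact."""
    if isinstance(obs, tuple):
        return tuple(flatten_observation(o) for o in obs)
    if isinstance(obs, dict):
        return {k: flatten_observation(v) for k, v in obs.items()}
    return flatten_leaf(obs)


image = [[1, 2], [3, 4]]   # stand-in for a 2x2 array leaf
features = [[5, 6, 7]]     # stand-in for a 1x3 array leaf
flat = flatten_observation((image, features))
```

The result is still a tuple of two observations, but each one is now one-dimensional.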

reset(key: Array) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], Any]
step(state: Any, action: int) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], Any, float, bool, bool, Dict[str, Any]]

class jaxatari.wrappers.JaxatariWrapper(env)

Bases: object

Base class for JAXAtari wrappers.

class jaxatari.wrappers.LogState(atari_state: Any, episode_returns: float, episode_lengths: int, returned_episode_returns: float, returned_episode_lengths: int)

Bases: object

atari_state: Any
episode_lengths: int
episode_returns: float
replace(**updates)

Returns a new object replacing the specified fields with new values.

returned_episode_lengths: int
returned_episode_returns: float

class jaxatari.wrappers.LogWrapper(env)

Bases: JaxatariWrapper

Log episode returns and lengths. An episode ends when the wrapped env returns done=True. Uses env_reward from info when present (unclipped); otherwise uses the step reward.
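The bookkeeping implied by the LogState fields can be sketched as follows. DemoLogState and log_step are hypothetical names, and resetting the running counters when an episode ends is an assumption, not confirmed by this reference. A jitted implementation would also select between the two branches with array operations rather than a Python if.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class DemoLogState:
    """Hypothetical stand-in with the same counters as LogState."""
    episode_returns: float = 0.0
    episode_lengths: int = 0
    returned_episode_returns: float = 0.0
    returned_episode_lengths: int = 0


def log_step(state, reward, done, info):
    # Prefer the unclipped env_reward from info when it is present.
    r = info.get("env_reward", reward)
    new_return = state.episode_returns + r
    new_length = state.episode_lengths + 1
    if done:
        # Publish the finished episode and reset the running counters
        # (assumed behaviour).
        return dataclasses.replace(
            state,
            episode_returns=0.0,
            episode_lengths=0,
            returned_episode_returns=new_return,
            returned_episode_lengths=new_length,
        )
    return dataclasses.replace(
        state, episode_returns=new_return, episode_lengths=new_length
    )


s = DemoLogState()
s = log_step(s, 1.0, False, {"env_reward": 5.0})  # uses env_reward = 5.0
s = log_step(s, 1.0, True, {})                    # falls back to reward = 1.0
```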

reset(key: Array) → Tuple[Array | ndarray | bool_ | number, LogState]
step(state: LogState, action: int) → Tuple[Array | ndarray | bool_ | number, LogState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.MultiRewardLogState(atari_state: Any, episode_returns_env: float, episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, episode_lengths: int, returned_episode_returns_env: float, returned_episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, returned_episode_lengths: int)

Bases: object

atari_state: Any
episode_lengths: int
episode_returns: Array | ndarray | bool_ | number
episode_returns_env: float
replace(**updates)

Returns a new object replacing the specified fields with new values.

returned_episode_lengths: int
returned_episode_returns: Array | ndarray | bool_ | number
returned_episode_returns_env: float

class jaxatari.wrappers.MultiRewardLogWrapper(env)

Bases: JaxatariWrapper

Log episode returns and lengths for multiple rewards. An episode ends when the wrapped env returns done=True. Apply MultiRewardWrapper to the core env when using this wrapper. Final logs: ‘returned_episode_returns_0’, … for each reward function; env reward in ‘returned_episode_env_returns’.

reset(key: Array) → Tuple[Array | ndarray | bool_ | number, MultiRewardLogState]
step(state: MultiRewardLogState, action: int) → Tuple[Array | ndarray | bool_ | number, MultiRewardLogState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.MultiRewardWrapper(env, reward_funcs: list[Callable])

Bases: JaxatariWrapper

Allows providing multiple reward functions to be computed at every step. Apply this wrapper directly after the base environment, before any other wrappers.

step(state: EnvState, action: int) → Tuple[Array | ndarray | bool_ | number, EnvState, float, bool, Dict]

class jaxatari.wrappers.NormalizeObservationWrapper(env, to_neg_one: bool = False, dtype=jax.numpy.float16)

Bases: JaxatariWrapper

A wrapper that normalizes each leaf in an observation Pytree. This wrapper is compatible with any observation structure (Pytrees).

observation_space() → Space

Returns the normalized observation space where leaves are in [0, 1].
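As a rough sketch of the arithmetic, normalizing a value with known bounds is a min-max rescale. The normalize helper below is hypothetical. It assumes bounds of 0 and 255 (typical for pixel leaves; the wrapper presumably takes them from the observation space) and assumes that to_neg_one maps the result to [-1, 1] instead of [0, 1].

```python
def normalize(value, low=0.0, high=255.0, to_neg_one=False):
    """Min-max rescale value from [low, high] to [0, 1] or [-1, 1]."""
    unit = (value - low) / (high - low)   # now in [0, 1]
    if to_neg_one:
        return 2.0 * unit - 1.0           # stretch to [-1, 1] (assumed)
    return unit


white = normalize(255.0)
black_signed = normalize(0.0, to_neg_one=True)
mid_signed = normalize(127.5, to_neg_one=True)
```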

reset(key: Array) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], Any]
step(state: Any, action: int) → Tuple[Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], Any, float, bool, bool, Dict[str, Any]]

class jaxatari.wrappers.ObjectCentricState(atari_state: jaxatari.wrappers.AtariState, obs_stack: jax.Array)

Bases: object

atari_state: AtariState
obs_stack: Array
replace(**updates)

Returns a new object replacing the specified fields with new values.

class jaxatari.wrappers.ObjectCentricWrapper(env, frame_stack_size: int = 4, frame_skip: int = 4, clip_reward: bool = True)

Bases: JaxatariWrapper

Wrapper for Atari environments that returns stacked object-centric observations. The output observation is a 2D array of shape (frame_stack_size, num_features). Apply this wrapper after the AtariWrapper!
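The (frame_stack_size, num_features) layout can be pictured as a rolling window over the most recent feature vectors. The sketch below uses plain lists and a hypothetical push_frame helper. Filling the stack with copies of the first observation on reset is an assumption, not something this reference states.

```python
def push_frame(stack, frame):
    """Drop the oldest frame and append the newest one."""
    return stack[1:] + [frame]


frame_stack_size = 4
first = [0.0, 0.0]                    # num_features = 2
stack = [first] * frame_stack_size    # assumed reset behaviour
stack = push_frame(stack, [1.0, 2.0])
```

After one step the stack still holds four rows, with the newest observation in the last position.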

observation_space() → Box

Returns a Box space for the flattened observation.

reset(key: Array) → Tuple[Array | ndarray | bool_ | number, ObjectCentricState]
step(state: ObjectCentricState, action: int) → Tuple[Array | ndarray | bool_ | number, ObjectCentricState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.PixelAndObjectCentricState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array, obs_stack: Any)

Bases: object

atari_state: AtariState
image_stack: Array
obs_stack: Any
replace(**updates)

Returns a new object replacing the specified fields with new values.

class jaxatari.wrappers.PixelAndObjectCentricWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)

Bases: JaxatariWrapper

Wrapper for Atari environments that returns the flattened pixel observations and object-centric observations. Apply this wrapper after the AtariWrapper!

observation_space() → Tuple

Returns a Tuple space containing stacked image and object spaces.

reset(key: Array) → Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState]
step(state: PixelAndObjectCentricState, action: int) → Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.PixelAndObjectObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)

Bases: PixelAndObjectCentricWrapper

Identical to PixelAndObjectCentricWrapper, except that it returns structured object-centric observations instead of a flattened array.

reset(key: Array) → Tuple[Tuple[Array | ndarray | bool_ | number, Any], PixelAndObjectCentricState]
step(state: PixelAndObjectCentricState, action: int) → Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.PixelObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)

Bases: JaxatariWrapper

Wrapper for Atari environments that returns the flattened pixel observations. Apply this wrapper after the AtariWrapper!
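The frame_skip, max_pooling and clip_reward options follow the usual Atari preprocessing recipe. The sketch below shows one common formulation: repeat the action for frame_skip frames, sum the rewards, take the element-wise maximum over the last two frames, and clip the summed reward to its sign. The names skip_step and toy_step and the four-value step signature are hypothetical, and the wrapper's actual internals may differ (a jitted version would use a scan rather than a Python loop with break).

```python
def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def skip_step(env_step, state, action, frame_skip=4,
              max_pooling=True, clip_reward=True):
    """Repeat action for up to frame_skip frames of a hypothetical env_step."""
    total_reward, frames, done = 0.0, [], False
    for _ in range(frame_skip):
        frame, state, reward, done = env_step(state, action)
        frames.append(frame)
        total_reward += reward
        if done:
            break
    if max_pooling and len(frames) >= 2:
        # Element-wise maximum over the last two frames.
        obs = [max(a, b) for a, b in zip(frames[-2], frames[-1])]
    else:
        obs = frames[-1]
    if clip_reward:
        total_reward = float(sign(total_reward))
    return obs, state, total_reward, done


def toy_step(state, action):
    # Hypothetical environment: the state is a counter and each frame
    # is a two-element list derived from it.
    state += 1
    frame = [state % 2, state]
    return frame, state, 2.0, False


obs, state, reward, done = skip_step(toy_step, 0, action=0)
```

In this toy run the last two frames are [1, 3] and [0, 4], so the pooled observation is [1, 4], and the summed reward of 8.0 is clipped to 1.0.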

observation_space() → Space

Returns the stacked image space.

reset(key: Array) → Tuple[Array | ndarray | bool_ | number, PixelState]
step(state: PixelState, action: int) → Tuple[Array | ndarray | bool_ | number, PixelState, float, bool, bool, Dict[Any, Any]]

class jaxatari.wrappers.PixelState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array)

Bases: object

atari_state: AtariState
image_stack: Array
replace(**updates)

Returns a new object replacing the specified fields with new values.

jaxatari.wrappers.preprocess_image(class_instance: JaxatariWrapper, image: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Applies resizing and grayscaling to a single image frame.
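For intuition only, the two operations can be sketched on nested lists. The luma weights (0.299, 0.587, 0.114) and nearest-neighbour sampling used here are assumptions chosen for simplicity. This reference does not say which grayscale formula or interpolation method preprocess_image uses, and both helper names are hypothetical.

```python
def to_grayscale(image):
    """Convert rows of (r, g, b) pixels to rows of luma values."""
    return [[0.299 * r + 0.587 * g + 0.114 * b for (r, g, b) in row]
            for row in image]


def resize_nearest(image, out_h, out_w):
    """Resize a 2D image by nearest-neighbour sampling."""
    in_h, in_w = len(image), len(image[0])
    return [[image[y * in_h // out_h][x * in_w // out_w]
             for x in range(out_w)]
            for y in range(out_h)]


img = [[(255, 255, 255)] * 4 for _ in range(4)]   # 4x4 white RGB frame
gray = to_grayscale(img)                          # 4x4 single-channel
small = resize_nearest(gray, 2, 2)                # 2x2 single-channel
```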