Wrappers
Jaxatari Wrappers
- class jaxatari.wrappers.AtariState(env_state: Any, key: jax.Array, step: int, prev_action: int)
Bases: object
- env_state: Any
- key: Array
- prev_action: int
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- step: int
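The `replace` method follows the immutable-update pattern: it returns a new state object and leaves the original untouched. The sketch below illustrates that pattern with a frozen stdlib dataclass. `SketchState` is a hypothetical stand-in, not the library class, and it stores plain Python values where the real `AtariState` holds JAX arrays.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class SketchState:
    """Hypothetical stand-in that mirrors the AtariState field names."""
    env_state: Any
    key: Any
    step: int
    prev_action: int

    def replace(self, **updates):
        # Build a new object with the given fields swapped out.
        return dataclasses.replace(self, **updates)


old = SketchState(env_state=None, key=0, step=0, prev_action=0)
new = old.replace(step=old.step + 1, prev_action=3)
```

Because the old state is never mutated, both `old` and `new` remain valid values, which is what functional transformations such as `jax.jit` expect from a state container.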
- class jaxatari.wrappers.AtariWrapper(env, sticky_actions: float = 0.25, episodic_life: bool = True, first_fire: bool = True, noop_max: int = 30, full_action_space: bool = False, max_frames_per_episode: int = 108000)
Bases: JaxatariWrapper
Wrapper for Atari environments that applies Atari-specific control logic. Returns single-step, single-frame observations from the wrapped base env.
Parameters:
- env: The environment to wrap.
- sticky_actions: Sticky action probability in [0, 1]. Defaults to 0.25.
- episodic_life: Report loss of a life as terminated, without resetting the environment. Defaults to True.
- first_fire: Take the FIRE action on reset. Defaults to True.
- noop_max: Maximum number of no-op actions to take on reset. Defaults to 30.
- full_action_space: Use the full action space of 18 actions. Defaults to False (minimal action set).
- max_frames_per_episode: Maximum number of frames per episode before truncation. Defaults to 108,000 (30 minutes at 60 fps).
Note: This wrapper is typically followed by PixelObsWrapper, ObjectCentricWrapper or PixelAndObjectCentricWrapper, which handle frame-skipping, max-pooling, frame-stacking and reward clipping.
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, AtariState]
- step(state: AtariState, action: int) -> Tuple[Array | ndarray | bool_ | number, AtariState, float, bool, bool, Dict[Any, Any]]
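To make the `sticky_actions` parameter concrete, here is a minimal sketch of the usual sticky-action rule: with probability p the previously executed action is repeated instead of the requested one. This is plain Python using the stdlib `random` module, not the wrapper's JAX implementation, and `sticky_action` is a hypothetical helper. Whether the library draws the random number once per step or once per frame is not stated here and is left open.

```python
import random


def sticky_action(prev_action: int, action: int, p: float, rng: random.Random) -> int:
    """With probability p, repeat the previous action instead of the new one."""
    return prev_action if rng.random() < p else action


rng = random.Random(0)
prev = 0  # start from NOOP (assumed to be action 0 in this sketch)
executed = []
for requested in [1, 2, 3, 4]:
    prev = sticky_action(prev, requested, p=0.25, rng=rng)
    executed.append(prev)

# Frame budget implied by the defaults: 108,000 frames at 60 fps.
minutes = 108_000 / 60 / 60
```

The last line reproduces the arithmetic behind the `max_frames_per_episode` default: 108,000 frames at 60 frames per second is 1,800 seconds, or 30 minutes.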
- class jaxatari.wrappers.FlattenObservationWrapper(env)
Bases: JaxatariWrapper
A wrapper that flattens each leaf array in an observation Pytree.
Compatible with all the other wrappers. It flattens the observations while preserving the overall structure (e.g. if the observation is a tuple of multiple observations, the flattened observation is a tuple of flattened observations).
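The structure-preserving behaviour described above can be sketched without JAX. In this illustration, tuples play the role of Pytree containers and nested lists play the role of array leaves; both helper functions are hypothetical and only mirror the idea, not the library's implementation.

```python
def flatten_leaf(leaf):
    """Flatten an arbitrarily nested list (a stand-in for an array) to 1D."""
    if isinstance(leaf, list):
        out = []
        for item in leaf:
            out.extend(flatten_leaf(item))
        return out
    return [leaf]


def flatten_observation(obs):
    """Flatten every leaf while keeping the tuple structure intact."""
    if isinstance(obs, tuple):
        return tuple(flatten_observation(o) for o in obs)
    return flatten_leaf(obs)


image = [[1, 2], [3, 4]]   # 2x2 leaf
objects = [[5, 6, 7]]      # 1x3 leaf
flat = flatten_observation((image, objects))
```

The result is still a 2-tuple, but each element is now one-dimensional.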
- class jaxatari.wrappers.JaxatariWrapper(env)
Bases: object
Base class for JAXAtari wrappers.
- class jaxatari.wrappers.LogState(atari_state: Any, episode_returns: float, episode_lengths: int, returned_episode_returns: float, returned_episode_lengths: int)
Bases: object
- atari_state: Any
- episode_lengths: int
- episode_returns: float
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- returned_episode_lengths: int
- returned_episode_returns: float
- class jaxatari.wrappers.LogWrapper(env)
Bases: JaxatariWrapper
Log episode returns and lengths. An episode ends when the wrapped env returns done=True. Uses env_reward from info when present (unclipped); otherwise uses the step reward.
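The bookkeeping implied by the `LogState` fields can be sketched as follows: running totals accumulate during an episode, and when the episode ends they are copied into the `returned_*` fields and reset. This is an assumption about how the fields are used, written in plain Python. `SketchLogState` and `log_step` are hypothetical names, and the sketch omits the `atari_state` field.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchLogState:
    episode_returns: float = 0.0
    episode_lengths: int = 0
    returned_episode_returns: float = 0.0
    returned_episode_lengths: int = 0


def log_step(state, reward, done, info=None):
    # Prefer the unclipped env_reward from info when it is present.
    env_reward = (info or {}).get("env_reward", reward)
    ret = state.episode_returns + env_reward
    length = state.episode_lengths + 1
    if done:
        # Episode finished: publish the totals and reset the running counters.
        return SketchLogState(0.0, 0, ret, length)
    return SketchLogState(ret, length,
                          state.returned_episode_returns,
                          state.returned_episode_lengths)


state = SketchLogState()
for reward, done in [(1.0, False), (0.0, False), (2.0, True), (1.0, False)]:
    state = log_step(state, reward, done)
```

After the loop, the first episode (three steps, return 3.0) is stored in the `returned_*` fields, while the running counters already track the first step of the next episode.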
- class jaxatari.wrappers.MultiRewardLogState(atari_state: Any, episode_returns_env: float, episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, episode_lengths: int, returned_episode_returns_env: float, returned_episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, returned_episode_lengths: int)
Bases: object
- atari_state: Any
- episode_lengths: int
- episode_returns: Array | ndarray | bool_ | number
- episode_returns_env: float
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- returned_episode_lengths: int
- returned_episode_returns: Array | ndarray | bool_ | number
- returned_episode_returns_env: float
- class jaxatari.wrappers.MultiRewardLogWrapper(env)
Bases: JaxatariWrapper
Log episode returns and lengths for multiple rewards. An episode ends when the wrapped env returns done=True. Apply MultiRewardWrapper to the core env when using this wrapper. Final logs: ‘returned_episode_returns_0’, … for each reward function; env reward in ‘returned_episode_env_returns’.
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, MultiRewardLogState]
- step(state: MultiRewardLogState, action: int) -> Tuple[Array | ndarray | bool_ | number, MultiRewardLogState, float, bool, bool, Dict[Any, Any]]
- class jaxatari.wrappers.MultiRewardWrapper(env, reward_funcs: list[Callable])
Bases: JaxatariWrapper
Allows providing multiple reward functions to be computed at every step. Apply this wrapper directly after the base environment, before any other wrappers.
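As an illustration of evaluating several reward functions on every step, the sketch below applies a list of callables to a pair of consecutive states. The call signature `(prev_state, state)`, the dict-based states, and the two example functions are all assumptions made for this sketch; the signature the library expects for `reward_funcs` entries is not given here.

```python
from typing import Callable, Dict, List


def compute_rewards(reward_funcs: List[Callable], prev_state: Dict, state: Dict) -> List[float]:
    """Evaluate every reward function on the same transition."""
    return [float(f(prev_state, state)) for f in reward_funcs]


def score_delta(prev, cur):
    # Hypothetical reward: change in score.
    return cur["score"] - prev["score"]


def survival_bonus(prev, cur):
    # Hypothetical reward: small bonus while no life is lost.
    return 0.1 if cur["lives"] == prev["lives"] else -1.0


rewards = compute_rewards(
    [score_delta, survival_bonus],
    {"score": 0, "lives": 3},
    {"score": 1, "lives": 3},
)
```

Each entry of `rewards` corresponds to one function in the list, in order, which matches the indexed log names used by MultiRewardLogWrapper.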
- class jaxatari.wrappers.NormalizeObservationWrapper(env, to_neg_one: bool = False, dtype=jax.numpy.float16)
Bases: JaxatariWrapper
A wrapper that normalizes each leaf in an observation Pytree. This wrapper is compatible with any observation structure (Pytrees).
- observation_space() -> Space
Returns the normalized observation space where leaves are in [0, 1].
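A minimal sketch of per-value normalization, assuming each leaf has known lower and upper bounds (for example 0 and 255 for 8-bit pixels). The mapping to [-1, 1] when `to_neg_one` is set is inferred from the parameter name and is an assumption, as is the hypothetical `normalize` helper itself.

```python
def normalize(value, low, high, to_neg_one=False):
    """Map value from [low, high] to [0, 1], or to [-1, 1] if to_neg_one is set."""
    unit = (value - low) / (high - low)
    return 2.0 * unit - 1.0 if to_neg_one else unit


pixels = [0, 51, 255]
unit_range = [normalize(p, 0, 255) for p in pixels]
signed_range = [normalize(p, 0, 255, to_neg_one=True) for p in pixels]
```

In the wrapper the same arithmetic would be applied elementwise to every leaf array, with the result cast to the configured `dtype`.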
- class jaxatari.wrappers.ObjectCentricState(atari_state: jaxatari.wrappers.AtariState, obs_stack: jax.Array)
Bases: object
- atari_state: AtariState
- obs_stack: Array
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class jaxatari.wrappers.ObjectCentricWrapper(env, frame_stack_size: int = 4, frame_skip: int = 4, clip_reward: bool = True)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns stacked object-centric observations. The output observation is a 2D array of shape (frame_stack_size, num_features). Apply this wrapper after the AtariWrapper!
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, ObjectCentricState]
- step(state: ObjectCentricState, action: int) -> Tuple[Array | ndarray | bool_ | number, ObjectCentricState, float, bool, bool, Dict[Any, Any]]
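The (frame_stack_size, num_features) output shape comes from keeping the most recent feature vectors in a fixed-length stack. The sketch below shows that mechanism with a stdlib `deque`. Filling the stack with copies of the first observation on reset is a common convention and is assumed here, not confirmed for this wrapper; `reset_stack` and `push` are hypothetical helpers.

```python
from collections import deque


def reset_stack(first_obs, frame_stack_size=4):
    """Start an episode with the first observation repeated in every slot."""
    return deque((list(first_obs) for _ in range(frame_stack_size)),
                 maxlen=frame_stack_size)


def push(stack, obs):
    """Return a new stack with the oldest entry dropped and obs appended."""
    new = deque(stack, maxlen=stack.maxlen)
    new.append(list(obs))
    return new


stack = reset_stack([0.0, 1.0])   # num_features = 2
stack = push(stack, [2.0, 3.0])
shape = (len(stack), len(stack[0]))
```

`push` copies the stack instead of mutating it, mirroring how the wrapper state is replaced on each step.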
- class jaxatari.wrappers.PixelAndObjectCentricState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array, obs_stack: Any)
Bases: object
- atari_state: AtariState
- image_stack: Array
- obs_stack: Any
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class jaxatari.wrappers.PixelAndObjectCentricWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns the flattened pixel observations and object-centric observations. Apply this wrapper after the AtariWrapper!
- observation_space() -> Tuple
Returns a Tuple space containing stacked image and object spaces.
- reset(key: Array) -> Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState]
- step(state: PixelAndObjectCentricState, action: int) -> Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]
- class jaxatari.wrappers.PixelAndObjectObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)
Bases: PixelAndObjectCentricWrapper
Exactly the same as PixelAndObjectCentricWrapper, but returns structured object-centric observations instead of a flattened array.
- reset(key: Array) -> Tuple[Tuple[Array | ndarray | bool_ | number, Any], PixelAndObjectCentricState]
- step(state: PixelAndObjectCentricState, action: int) -> Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]
- class jaxatari.wrappers.PixelObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False, frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True, clip_reward: bool = True)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns the flattened pixel observations. Apply this wrapper after the AtariWrapper!
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, PixelState]
- step(state: PixelState, action: int) -> Tuple[Array | ndarray | bool_ | number, PixelState, float, bool, bool, Dict[Any, Any]]
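The `frame_skip`, `max_pooling` and `clip_reward` parameters correspond to the standard Atari preprocessing steps. The sketch below assumes the conventional semantics: repeat the action for `frame_skip` frames, sum the rewards, take the elementwise maximum of the last two frames, and clip the summed reward to its sign. Whether the wrapper follows exactly this order is an assumption. `skip_step` and `toy_env_step` are hypothetical, and early termination inside the skip loop is left out for brevity.

```python
def skip_step(env_step, state, action, frame_skip=4, max_pooling=True, clip_reward=True):
    """Repeat `action` for frame_skip frames and post-process the result."""
    frames, total = [], 0.0
    for _ in range(frame_skip):
        state, frame, reward = env_step(state, action)
        frames.append(frame)
        total += reward
    if max_pooling and len(frames) >= 2:
        # Elementwise max over the last two frames.
        frame = [max(a, b) for a, b in zip(frames[-2], frames[-1])]
    else:
        frame = frames[-1]
    if clip_reward:
        # Sign of the summed reward: -1.0, 0.0 or 1.0.
        total = float((total > 0) - (total < 0))
    return state, frame, total


def toy_env_step(state, action):
    """Hypothetical env: state is a frame counter, every frame pays reward 2."""
    state = state + 1
    return state, [state, 10 - state], 2.0


state, frame, reward = skip_step(toy_env_step, 0, action=0)
```

With the toy environment, the last two frames are [3, 7] and [4, 6], so the pooled frame is [4, 7], and the summed reward of 8.0 is clipped to 1.0.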
- class jaxatari.wrappers.PixelState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array)
Bases: object
- atari_state: AtariState
- image_stack: Array
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- jaxatari.wrappers.preprocess_image(class_instance: JaxatariWrapper, image: Array | ndarray | bool_ | number) -> Array | ndarray | bool_ | number
Applies resizing and grayscaling to a single image frame.
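For intuition about the two operations, here is a dependency-free sketch of grayscale conversion followed by downscaling. The BT.601 luma weights and nearest-neighbour sampling are assumptions chosen for simplicity; the library's actual weights and interpolation method are not specified here, and `to_grayscale` and `resize_nearest` are hypothetical helpers operating on nested lists where the real function takes arrays.

```python
def to_grayscale(image):
    """H x W x 3 nested lists -> H x W nested lists (BT.601 luma weights)."""
    return [[0.299 * r + 0.587 * g + 0.114 * b for r, g, b in row] for row in image]


def resize_nearest(image, out_h, out_w):
    """Nearest-neighbour resize of an H x W nested list."""
    in_h, in_w = len(image), len(image[0])
    return [[image[y * in_h // out_h][x * in_w // out_w] for x in range(out_w)]
            for y in range(out_h)]


# 4x4 RGB test image whose three channels all equal 10 * row + column.
image = [[[10 * y + x] * 3 for x in range(4)] for y in range(4)]
gray = to_grayscale(image)
small = resize_nearest(gray, 2, 2)
```

Here `small` samples rows 0 and 2 and columns 0 and 2 of the grayscale image, giving a 2x2 result.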