Wrappers
Jaxatari Wrappers
- class jaxatari.wrappers.AtariState(env_state: EnvState, key: jax.Array, step: int, prev_action: int, obs_stack: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number)
Bases: object
- env_state: EnvState
- key: Array
- obs_stack: Array | ndarray | bool_ | number
- prev_action: int
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- step: int
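The `replace(**updates)` method documented above returns a new object instead of mutating the existing one. The following is a minimal sketch of that pattern using the standard-library `dataclasses.replace` on a frozen dataclass. `ToyState` is a hypothetical stand-in with two of the documented fields, not the library's `AtariState`, and the real implementation may differ.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in carrying two of the AtariState fields."""
    step: int
    prev_action: int

    def replace(self, **updates):
        # Build a new instance with the given fields changed;
        # the original instance is left untouched.
        return replace(self, **updates)


old = ToyState(step=0, prev_action=0)
new = old.replace(step=old.step + 1, prev_action=3)
```

Because the state is never modified in place, the same pattern works inside traced or jitted functions, where in-place mutation is not available.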
- class jaxatari.wrappers.AtariWrapper(env, sticky_actions: bool = True, frame_stack_size: int = 4, frame_skip: int = 4, max_episode_length: int = 10000, episodic_life: bool = True, first_fire: bool = True, noop_reset: int = 0, clip_reward: bool = False, max_pooling: bool = False, full_action_space: bool = False)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns the rendered image and object-centric observations unflattened. Both are stacked by frame_stack_size.
Parameters:
  - env: The environment to wrap.
  - sticky_actions: Whether to use sticky actions.
  - frame_stack_size: The number of frames to stack.
  - frame_skip: The number of frames to skip.
- step(state: AtariState, action: int | float) -> Tuple[Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number], AtariState, float, bool, Dict[Any, Any]]
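To illustrate what stacking by `frame_stack_size` means, here is a minimal sketch of a rolling frame stack in plain Python. It assumes the stack is ordered oldest-first and is filled with copies of the first observation at reset; both are assumptions made for illustration, and `init_stack` and `push_frame` are hypothetical helpers, not part of the library.

```python
from typing import Any, Tuple


def init_stack(first_obs: Any, frame_stack_size: int = 4) -> Tuple[Any, ...]:
    # Assumption: at reset the stack holds copies of the first observation.
    return tuple(first_obs for _ in range(frame_stack_size))


def push_frame(stack: Tuple[Any, ...], new_obs: Any) -> Tuple[Any, ...]:
    # Drop the oldest entry and append the newest; the length stays constant.
    return stack[1:] + (new_obs,)


stack = init_stack("obs0", frame_stack_size=4)
stack = push_frame(stack, "obs1")
stack = push_frame(stack, "obs2")
```

In an array-based implementation the stack would typically be a single array with a leading axis of length `frame_stack_size`, updated functionally in the same way.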
- class jaxatari.wrappers.FlattenObservationWrapper(env)
Bases: JaxatariWrapper
A wrapper that flattens each leaf array in an observation Pytree.
It is compatible with all the other wrappers and flattens the observations while preserving the overall structure: if the observation is a tuple of multiple observations, the flattened observation is a tuple of flattened observations.
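The structure-preserving behaviour described above can be sketched without any array library. In this illustration nested lists stand in for leaf arrays and tuples stand in for the surrounding Pytree structure; `flatten_leaf` and `flatten_observation` are hypothetical helpers, not the wrapper's actual code.

```python
def flatten_leaf(leaf):
    """Flatten an arbitrarily nested list into a flat list of scalars."""
    if isinstance(leaf, list):
        out = []
        for item in leaf:
            out.extend(flatten_leaf(item))
        return out
    return [leaf]


def flatten_observation(obs):
    """Flatten every leaf but keep the surrounding tuple structure."""
    if isinstance(obs, tuple):
        return tuple(flatten_observation(o) for o in obs)
    return flatten_leaf(obs)


image = [[1, 2], [3, 4]]   # a 2x2 "image" leaf
objects = [[5, 6, 7]]      # a 1x3 "object features" leaf
flat = flatten_observation((image, objects))
```

The result is still a tuple with two entries, but each entry is now one-dimensional.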
- class jaxatari.wrappers.JaxatariWrapper(env)
Bases: object
Base class for JAXAtari wrappers.
- class jaxatari.wrappers.LogState(atari_state: Any, episode_returns: float, episode_lengths: int, returned_episode_returns: float, returned_episode_lengths: int)
Bases: object
- atari_state: Any
- episode_lengths: int
- episode_returns: float
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- returned_episode_lengths: int
- returned_episode_returns: float
- class jaxatari.wrappers.LogWrapper(env)
Bases: JaxatariWrapper
Log the episode returns and lengths.
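The `LogState` fields suggest a common bookkeeping pattern: running totals accumulate during an episode and are copied into the `returned_*` fields when the episode ends. The sketch below shows that pattern under the assumption that the wrapper behaves this way; `ToyLogState` and `update_log` are hypothetical names, and the wrapper's actual update rule may differ.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyLogState:
    """Hypothetical stand-in mirroring the LogState bookkeeping fields."""
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int


def update_log(state: ToyLogState, reward: float, done: bool) -> ToyLogState:
    new_return = state.episode_returns + reward
    new_length = state.episode_lengths + 1
    return replace(
        state,
        # Running totals reset when the episode ends.
        episode_returns=0.0 if done else new_return,
        episode_lengths=0 if done else new_length,
        # Final totals are only overwritten when an episode ends.
        returned_episode_returns=new_return if done else state.returned_episode_returns,
        returned_episode_lengths=new_length if done else state.returned_episode_lengths,
    )


state = ToyLogState(0.0, 0, 0.0, 0)
for reward, done in [(1.0, False), (2.0, False), (0.5, True), (4.0, False)]:
    state = update_log(state, reward, done)
```

After the loop, the `returned_*` fields hold the totals of the finished three-step episode, while the running fields describe the episode that has just started.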
- class jaxatari.wrappers.MultiRewardLogState(atari_state: Any, episode_returns_env: float, episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, episode_lengths: int, returned_episode_returns_env: float, returned_episode_returns: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, returned_episode_lengths: int)
Bases: object
- atari_state: Any
- episode_lengths: int
- episode_returns: Array | ndarray | bool_ | number
- episode_returns_env: float
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- returned_episode_lengths: int
- returned_episode_returns: Array | ndarray | bool_ | number
- returned_episode_returns_env: float
- class jaxatari.wrappers.MultiRewardLogWrapper(env)
Bases: JaxatariWrapper
Log the episode returns and lengths for multiple rewards. Make sure to apply MultiRewardWrapper to the core env when using this wrapper. The final logs will be ‘returned_episode_returns_0’, … for each reward function provided.
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, MultiRewardLogState]
- step(state: MultiRewardLogState, action: int | float) -> Tuple[Array | ndarray | bool_ | number, MultiRewardLogState, float, bool, Dict[Any, Any]]
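The indexed log names mentioned above can be illustrated with a short sketch that turns a sequence of per-reward-function returns into log entries. `multi_reward_info` is a hypothetical helper; only the key pattern `returned_episode_returns_<i>` comes from the description, and how the wrapper actually assembles its info dictionary is an assumption.

```python
from typing import Dict, Sequence


def multi_reward_info(returned_episode_returns: Sequence[float]) -> Dict[str, float]:
    # One entry per reward function, indexed in the order the functions were given.
    return {
        f"returned_episode_returns_{i}": value
        for i, value in enumerate(returned_episode_returns)
    }


info = multi_reward_info([10.0, -2.0, 0.5])
```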
- class jaxatari.wrappers.MultiRewardWrapper(env, reward_funcs: list[Callable])
Bases: JaxatariWrapper
Allows providing multiple reward functions to be computed at every step. Apply this wrapper directly after the base environment, before any other wrappers.
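As an illustration of evaluating several reward functions at every step, the sketch below applies a list of callables to a pair of consecutive states. The `(previous_state, state)` signature, the dictionary-based toy states, and all function names are assumptions made for this example; the signature that `reward_funcs` must have is not specified here.

```python
from typing import Callable, Dict, List

ToyState = Dict[str, int]  # hypothetical stand-in for an environment state


def score_delta(prev: ToyState, cur: ToyState) -> float:
    # Reward equal to the score gained during this step.
    return float(cur["score"] - prev["score"])


def life_penalty(prev: ToyState, cur: ToyState) -> float:
    # Penalty whenever a life was lost during this step.
    return -1.0 if cur["lives"] < prev["lives"] else 0.0


def compute_all_rewards(
    reward_funcs: List[Callable[[ToyState, ToyState], float]],
    prev: ToyState,
    cur: ToyState,
) -> List[float]:
    # Evaluate every reward function on the same transition.
    return [func(prev, cur) for func in reward_funcs]


prev_state = {"score": 10, "lives": 3}
cur_state = {"score": 15, "lives": 2}
rewards = compute_all_rewards([score_delta, life_penalty], prev_state, cur_state)
```

The resulting list has one entry per reward function, in the order the functions were passed.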
- class jaxatari.wrappers.NormalizeObservationWrapper(env, to_neg_one: bool = False, dtype=<class 'jax.numpy.float16'>)
Bases: JaxatariWrapper
A wrapper that normalizes each leaf in an observation Pytree. This wrapper is compatible with any observation structure (Pytrees).
- observation_space() -> Space
Returns the normalized observation space where leaves are in [0, 1].
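A minimal sketch of per-leaf normalization, assuming min-max scaling with the bounds of the observation space and assuming that `to_neg_one=True` rescales the result to [-1, 1]. Both assumptions are inferred from the parameter name and the [0, 1] range stated above; `normalize` is a hypothetical helper operating on scalars rather than arrays.

```python
def normalize(value: float, low: float, high: float, to_neg_one: bool = False) -> float:
    # Min-max scale into [0, 1] using the bounds of the space.
    unit = (value - low) / (high - low)
    # Assumption: to_neg_one shifts and stretches the range to [-1, 1].
    return 2.0 * unit - 1.0 if to_neg_one else unit


darkest = normalize(0, low=0, high=255)
brightest = normalize(255, low=0, high=255)
darkest_signed = normalize(0, low=0, high=255, to_neg_one=True)
brightest_signed = normalize(255, low=0, high=255, to_neg_one=True)
```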
- class jaxatari.wrappers.ObjectCentricWrapper(env)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns stacked object-centric observations. The output observation is a 2D array of shape (frame_stack_size, num_features). Apply this wrapper after the AtariWrapper!
- step(state: AtariState, action: int | float) -> Tuple[Array | ndarray | bool_ | number, EnvState, float, bool, Any]
- class jaxatari.wrappers.PixelAndObjectCentricState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number, obs_stack: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number)
Bases: object
- atari_state: AtariState
- image_stack: Array | ndarray | bool_ | number
- obs_stack: Array | ndarray | bool_ | number
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class jaxatari.wrappers.PixelAndObjectCentricWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns the flattened pixel observations and object-centric observations. Apply this wrapper after the AtariWrapper!
- observation_space() -> Tuple
Returns a Tuple space containing stacked image and object spaces.
- step(state: PixelAndObjectCentricState, action: int | float) -> Tuple[Array | ndarray | bool_ | number, EnvState, float, bool, Any]
- class jaxatari.wrappers.PixelAndObjectObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False)
Bases: PixelAndObjectCentricWrapper
Identical to PixelAndObjectCentricWrapper, but returns structured object-centric observations instead of a flattened array.
- step(state: PixelAndObjectCentricState, action: int | float) -> Tuple[Array | ndarray | bool_ | number, EnvState, float, bool, Any]
- class jaxatari.wrappers.PixelObsWrapper(env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False)
Bases: JaxatariWrapper
Wrapper for Atari environments that returns the flattened pixel observations. Apply this wrapper after the AtariWrapper!
- reset(key: Array) -> Tuple[Array | ndarray | bool_ | number, PixelState]
- step(state: PixelState, action: int | float) -> Tuple[Array | ndarray | bool_ | number, EnvState, float, bool, Any]
- class jaxatari.wrappers.PixelState(atari_state: jaxatari.wrappers.AtariState, image_stack: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number)
Bases: object
- atari_state: AtariState
- image_stack: Array | ndarray | bool_ | number
- replace(**updates)
Returns a new object replacing the specified fields with new values.