# Source code for jaxatari.wrappers

"""Jaxatari Wrappers"""
from absl.logging import info

import functools
import types
import warnings
from typing import Any, Dict, Tuple, Union, Optional, Callable
from dataclasses import is_dataclass, asdict

import chex
from flax import struct
import jax
import jax.image as jim
import jax.numpy as jnp
from jax import flatten_util
from jaxatari.environment import EnvState, JAXAtariAction as Action
import jaxatari.spaces as spaces
import numpy as np
from jaxatari.rendering.jax_rendering_utils import RendererConfig

class JaxatariWrapper(object):
    """Base class for JAXAtari wrappers.

    Stores the wrapped environment in ``self._env`` and forwards any attribute that is
    not found on the wrapper itself to the wrapped environment.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        """Forward failed attribute lookups to the wrapped environment.

        ``__getattr__`` only runs when normal attribute lookup fails. If ``_env`` itself is
        missing (e.g. on an instance created via ``__new__`` by ``copy.copy`` or unpickling),
        forwarding would re-enter ``__getattr__('_env')`` endlessly and raise RecursionError;
        raise AttributeError instead so ``hasattr``/copy protocols behave normally.
        """
        if name == "_env":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute '_env'")
        return getattr(self._env, name)
class MultiRewardWrapper(JaxatariWrapper):
    """
    Evaluates a list of reward functions on every transition.

    The per-function rewards are stored in ``info["all_rewards"]`` as one array with an entry
    per function (in list order); the wrapped env's own reward is passed through unchanged.
    Apply this wrapper directly after the base environment, before any other wrappers.
    """

    def __init__(self, env, reward_funcs: list[Callable]):
        super().__init__(env)
        # For a list, truthiness is equivalent to len(...) > 0.
        assert isinstance(reward_funcs, list) and reward_funcs, "reward_funcs must be a non-empty list of callables"
        self._reward_funcs = reward_funcs

    @functools.partial(jax.jit, static_argnums=(0,))
    def _get_all_rewards(self, previous_state: EnvState, state: EnvState) -> chex.Array:
        """Evaluate every reward function on (previous_state, state) and stack the results."""
        funcs = self._reward_funcs
        if funcs is None:
            # Defensive fallback only; __init__ always stores a non-empty list.
            return jnp.zeros(1)
        per_func_rewards = [fn(previous_state, state) for fn in funcs]
        return jnp.array(per_func_rewards)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, state: EnvState, action: int) -> Tuple[chex.Array, EnvState, float, bool, Dict]:
        """Step the wrapped env and attach the multi-reward vector under ``"all_rewards"``."""
        obs, new_state, reward, done, raw_info = self._env.step(state, action)

        # Normalize info to a mutable dict: NamedTuple-like objects expose _asdict,
        # dataclasses go through dataclasses.asdict, anything else is assumed dict-like.
        if hasattr(raw_info, '_asdict'):
            info_dict = raw_info._asdict()
        elif is_dataclass(raw_info):
            info_dict = asdict(raw_info)
        else:
            info_dict = raw_info

        info_dict["all_rewards"] = self._get_all_rewards(state, new_state)
        return obs, new_state, reward, done, info_dict
@struct.dataclass
class AtariState:
    """Per-episode state carried between calls by AtariWrapper."""

    # State of the wrapped base environment.
    env_state: Any
    # PRNG key split on every AtariWrapper.step (one half drives the sticky-action draw).
    key: chex.PRNGKey
    # Frame counter for the episode (includes reset no-ops / FIRE); compared against
    # max_frames_per_episode for truncation.
    step: int
    # Action actually executed on the previous frame; repeated when a sticky action fires.
    prev_action: int
class AtariWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that applies Atari-specific control logic.

    Returns single-step, single-frame observations from the wrapped base env.

    Args:
        env: The environment to wrap.
        sticky_actions: Sticky action probability, clipped to [0, 1]. Defaults to 0.25.
        episodic_life: Loss of life -> terminated. Does not reset the environment. Defaults to True.
        first_fire: Take FIRE action on reset. Defaults to True. Automatically disabled when the
            env exposes an ACTION_SET that contains no FIRE action.
        noop_max: Max number of no-op actions to take on reset. Defaults to 30; 0 disables no-op resets.
        full_action_space: Use full action space of 18 actions. Defaults to False (minimal action set).
        max_frames_per_episode: Maximum number of frames per episode before truncation.
            Defaults to 108,000 (30 minutes at 60fps).

    Note:
        Typically, this wrapper is followed by PixelObsWrapper, ObjectCentricWrapper or
        PixelAndObjectCentricWrapper. Frame-skipping, max-pooling, frame-stacking and reward
        clipping are handled in those.
    """

    def __init__(self, env, sticky_actions: float = 0.25, episodic_life: bool = True,
                 first_fire: bool = True, noop_max: int = 30, full_action_space: bool = False,
                 max_frames_per_episode: int = 108_000):
        super().__init__(env)
        self.sticky_actions = float(np.clip(sticky_actions, 0.0, 1.0))
        self.episodic_life = episodic_life
        self.first_fire = first_fire
        self.noop_reset = noop_max != 0
        self.noop_max = noop_max
        self.full_action_space = full_action_space
        self.max_frames_per_episode = max_frames_per_episode

        # --- 1) HANDLE FULL ACTION SPACE LOGIC ---
        # If requested, swap the environment's (minimal) action set for the full identity set.
        # This keeps each game env "clean" while enabling a central switch for experimentation.
        if self.full_action_space and hasattr(self._env, "ACTION_SET"):
            # Overwrite the instance attribute with [0, 1, ... 17]
            self._env.ACTION_SET = jnp.arange(18, dtype=jnp.int32)

        # --- 2) RESOLVE CORRECT 'FIRE' ACTION INDEX ---
        # The wrapped env expects an *index* into ACTION_SET (agent action), not the ALE action constant.
        self.fire_action_index: int = int(Action.FIRE)  # fallback if env doesn't expose ACTION_SET
        if hasattr(self._env, "ACTION_SET"):
            # Convert to numpy for search (safe in __init__, outside of tracing).
            action_set_np = np.array(self._env.ACTION_SET)
            fire_indices = np.where(action_set_np == int(Action.FIRE))[0]
            if len(fire_indices) > 0:
                self.fire_action_index = int(fire_indices[0])
            else:
                # Game has no FIRE action (e.g. Freeway).
                # Disable first_fire to prevent sending a random command by mistake.
                self.first_fire = False

        self._observation_space = self._env.observation_space()

    def observation_space(self) -> spaces.Space:
        """Returns the single-frame base observation space."""
        return self._observation_space

    def image_space(self) -> spaces.Box:
        """Returns the image space."""
        return self._env.image_space()

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, AtariState]:
        """Reset the env, then optionally apply random no-ops and a FIRE action.

        Returns:
            ``(obs, AtariState)`` where ``step`` counts the no-op/FIRE frames already taken.
        """
        # Split keys for all potential random operations
        env_key, wrapper_key, noop_key = jax.random.split(key, 3)
        obs, env_state = self._env.reset(env_key)
        step = jnp.array(0, dtype=jnp.int32)
        prev_action = jnp.array(0, dtype=jnp.int32)

        # `self` is a static jit argument, so noop_reset / first_fire are plain Python bools.
        # A Python `if` is therefore sufficient and avoids tracing and compiling branches
        # that can never run.

        # ========== NOOP RESET ==========
        if self.noop_reset:
            # Random number of no-op steps in [0, noop_max].
            num_noops = jax.random.randint(noop_key, shape=(), minval=0, maxval=self.noop_max + 1)

            def noop_body_fn(_, loop_carry):
                current_obs, current_env_state = loop_carry
                next_obs, next_env_state, _, done, _ = self._env.step(current_env_state, Action.NOOP)
                # Might be done from the no-op step, in which case we reset (with the same env_key).
                next_obs, next_env_state = jax.lax.cond(
                    done,
                    lambda: self._env.reset(env_key),
                    lambda: (next_obs, next_env_state),
                )
                return next_obs, next_env_state

            # The trip count is dynamic (num_noops is traced), so fori_loop lowers to a while_loop.
            obs, env_state = jax.lax.fori_loop(0, num_noops, noop_body_fn, (obs, env_state))
            step = step + num_noops

        # ========== FIRST FIRE ==========
        if self.first_fire:
            obs, env_state, _, _, _ = self._env.step(env_state, self.fire_action_index)
            step = step + 1
            prev_action = jnp.array(self.fire_action_index, dtype=jnp.int32)

        return obs, AtariState(env_state, wrapper_key, step, prev_action)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, state: AtariState, action: int) -> Tuple[chex.Array, AtariState, float, bool, bool, Dict[Any, Any]]:
        """Advance one frame with sticky actions, episodic-life termination and truncation.

        Returns:
            ``(obs, next_state, reward, terminated, truncated, info)``. ``info`` additionally holds
            ``"env_done"`` (the raw env done flag, unaffected by episodic life) and
            ``"env_reward"`` (the raw env reward).
        """
        step_key, next_state_key = jax.random.split(state.key)

        # Sticky actions: with probability `sticky_actions`, repeat the previous action.
        use_sticky_action = jax.random.uniform(step_key, shape=()) < self.sticky_actions
        new_action = jnp.where(use_sticky_action, state.prev_action, action)

        obs, new_env_state, reward, env_done, infos = self._env.step(state.env_state, new_action)

        terminated = env_done
        if self.episodic_life:
            # If the player has lost a life, we consider the episode done
            if hasattr(state.env_state, "lives"):
                life_lost = jnp.logical_and(state.env_state.lives > 0,
                                            new_env_state.lives < state.env_state.lives)
                terminated = jnp.logical_or(terminated, life_lost)
            elif hasattr(state.env_state, "lives_lost"):
                terminated = jnp.logical_or(terminated,
                                            new_env_state.lives_lost > state.env_state.lives_lost)

        # Normalize info into a fresh dict (NamedTuple-like, dataclass, or dict-like).
        if hasattr(infos, '_asdict'):
            info_dict = dict(infos._asdict())
        elif is_dataclass(infos):
            info_dict = asdict(infos)
        else:
            info_dict = dict(infos.items())

        next_state = AtariState(new_env_state, next_state_key, state.step + 1, new_action)

        # store actual done - not affected by episodic life
        info_dict["env_done"] = env_done
        # store actual reward in info dict before any downstream clipping
        info_dict["env_reward"] = reward

        truncated = (state.step + 1 >= self.max_frames_per_episode)
        return obs, next_state, reward, terminated, truncated, info_dict
@struct.dataclass
class ObjectCentricState:
    """State carried by ObjectCentricWrapper between steps."""

    # State of the underlying AtariWrapper.
    atari_state: AtariState
    # Stack of flattened float32 object-centric frames, oldest first, newest last.
    obs_stack: jax.Array
class ObjectCentricWrapper(JaxatariWrapper):
    """
    Returns frame-stacked, flattened object-centric observations.

    Each observation is a float32 array of shape ``(frame_stack_size, num_features)``, where a
    row is the env's object-centric observation pytree raveled into one vector. Each step
    repeats the action ``frame_skip`` times and sums (optionally sign-clips) the rewards.
    Apply this wrapper after the AtariWrapper!
    """

    def __init__(self, env, frame_stack_size: int = 4, frame_skip: int = 4, clip_reward: bool = True):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "ObjectCentricWrapper must be applied after AtariWrapper"
        self.frame_stack_size = frame_stack_size
        self.frame_skip = frame_skip
        self.clip_reward = clip_reward

        flat_low, flat_high = self._flat_bounds(self._env.observation_space())
        self._observation_space = spaces.Box(
            low=flat_low,
            high=flat_high,
            shape=(self.frame_stack_size, int(flat_low.shape[0])),
            dtype=jnp.float32,
        )

    @staticmethod
    def _flat_bounds(single_frame_space):
        """Concatenate per-element low/high bounds of every leaf space into two 1D arrays."""
        bounds = []
        for leaf in jax.tree.leaves(single_frame_space):
            if isinstance(leaf, spaces.Box):
                bounds.append((
                    np.broadcast_to(leaf.low, leaf.shape).flatten(),
                    np.broadcast_to(leaf.high, leaf.shape).flatten(),
                ))
            elif isinstance(leaf, spaces.Discrete):
                # A Discrete(n) leaf contributes a single feature in [0, n - 1].
                bounds.append((
                    np.array([0], dtype=leaf.dtype),
                    np.array([leaf.n - 1], dtype=leaf.dtype),
                ))
            else:
                raise TypeError(f"Unsupported space type for flattening: {type(leaf)}")
        if not bounds:
            raise ValueError("The observation space appears to be empty or contain unsupported types.")
        lows, highs = zip(*bounds)
        return np.concatenate(lows), np.concatenate(highs)

    def observation_space(self) -> spaces.Box:
        """Returns a Box space for the flattened observation."""
        return self._observation_space

    @functools.partial(jax.jit, static_argnums=(0,))
    def _flatten_single_obs(self, obs):
        """Ravel one object-centric observation pytree into a 1D float32 vector."""
        flat, _unravel = flatten_util.ravel_pytree(obs)
        return flat.astype(jnp.float32)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, ObjectCentricState]:
        """Reset the wrapped env and fill the stack with copies of the first frame."""
        obs, atari_state = self._env.reset(key)
        first_frame = self._flatten_single_obs(obs)
        initial_stack = jnp.stack([first_frame for _ in range(self.frame_stack_size)])
        return initial_stack, ObjectCentricState(atari_state, initial_stack)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: ObjectCentricState,
        action: int,
    ) -> Tuple[chex.Array, ObjectCentricState, float, bool, bool, Dict[Any, Any]]:
        """Repeat `action` for frame_skip frames, push the newest frame, auto-reset when done."""

        def scan_step(carry, _):
            inner_state, act = carry
            frame_obs, next_inner_state, rew, term, trunc, step_info = self._env.step(inner_state, act)
            return (next_inner_state, act), (frame_obs, rew, term, trunc, step_info)

        (atari_state, _), (obs_seq, rewards, terminations, truncations, infos) = jax.lax.scan(
            scan_step,
            (state.atari_state, action),
            None,
            length=self.frame_skip,
        )

        # Only the final frame of the skip window enters the stack.
        newest_obs = jax.tree.map(lambda leaf: leaf[-1], obs_seq)
        newest_flat = self._flatten_single_obs(newest_obs)
        obs_stack = jnp.concatenate([state.obs_stack[1:], newest_flat[jnp.newaxis]], axis=0)

        summed_reward = rewards.sum()
        reward = jnp.sign(summed_reward) if self.clip_reward else summed_reward

        terminated = terminations.any()
        truncated = truncations.any()

        # Autoreset (gym's SAME_STEP mode) -> reset whole stack. The condition uses the raw
        # env_done (not episodic-life termination); the reset consumes the current wrapper key.
        needs_reset = jnp.logical_or(infos["env_done"].any(), truncated)
        obs_stack, oc_state = jax.lax.cond(
            needs_reset,
            lambda: self.reset(atari_state.key),
            lambda: (obs_stack, ObjectCentricState(atari_state, obs_stack)),
        )

        # Collapse the per-frame info arrays: rewards are summed, done is OR-ed,
        # everything else keeps the value from the last frame.
        info_dict = {}
        for name, values in infos.items():
            if name in ("env_reward", "all_rewards"):
                info_dict[name] = jnp.sum(values, axis=0)
            elif name == "env_done":
                info_dict[name] = jnp.any(values)
            else:
                info_dict[name] = values[-1]

        return obs_stack, oc_state, reward, terminated, truncated, info_dict
@functools.partial(jax.jit, static_argnames=('sigma',))
def _gaussian_blur_2d_nchw(image: chex.Array, sigma: float = 3.0) -> chex.Array:
    """Blur an NCHW batch with a separable, per-channel Gaussian kernel (used by preprocess_image).

    Args:
        image: Array laid out as ``[N, C, H, W]``.
        sigma: Standard deviation of the Gaussian in pixels (static). The kernel spans
            ``int(3 * sigma)`` pixels on each side of the center.

    Returns:
        Blurred array with the same shape and dtype as ``image``. Borders are zero-padded
        (``'SAME'`` padding), so edge pixels are darkened.
    """
    num_channels = image.shape[1]
    half_width = int(sigma * 3)
    taps = jnp.linspace(-half_width, half_width, half_width * 2 + 1)
    weights = jnp.exp(-0.5 * (taps / sigma) ** 2)
    weights = (weights / weights.sum()).astype(image.dtype)

    def depthwise_pass(x, kernel_shape):
        # One OIHW kernel per channel; feature_group_count makes the convolution depthwise.
        kernel = jnp.tile(weights.reshape(kernel_shape), (num_channels, 1, 1, 1))
        return jax.lax.conv_general_dilated(
            x,
            kernel,
            (1, 1),
            padding='SAME',
            feature_group_count=num_channels,
            dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        )

    blurred_rows = depthwise_pass(image, (1, 1, 1, -1))   # horizontal pass along W
    return depthwise_pass(blurred_rows, (1, 1, -1, 1))    # vertical pass along H
@functools.partial(jax.jit, static_argnums=(0,))
def preprocess_image(class_instance: JaxatariWrapper, image: chex.Array) -> chex.Array:
    """Applies resizing, grayscaling and optional smoothing to a single image frame.

    Args:
        class_instance: Wrapper whose static attributes ``do_pixel_resize``,
            ``pixel_resize_shape``, ``grayscale``, ``native_downscaling`` and ``smooth_image``
            select the processing steps. It is a static jit argument, so plain Python ``if``
            statements can branch on it.
        image: One frame in channels-last ``(H, W, C)`` layout.

    Returns:
        The processed frame cast to ``uint8`` (the cast truncates; it does not round).
    """
    image = image.astype(jnp.float32)

    # Has to use a standard Python `if` since jax.lax.cond would fail due to different shapes.
    # This is possible since do_pixel_resize is a static parameter.
    if class_instance.do_pixel_resize:
        height, width = class_instance.pixel_resize_shape[0], class_instance.pixel_resize_shape[1]
        image = jim.resize(image, (height, width, image.shape[-1]), method='bilinear')

    # Luma weights as in https://en.wikipedia.org/wiki/Luma_(video); keeps a trailing channel axis.
    if class_instance.grayscale:
        image = jnp.dot(image, jnp.array([0.2989, 0.5870, 0.1140]))[..., jnp.newaxis]

    # Apply gaussian smoothing to natively downscaled images to get a similar effect to actual downscaling.
    if class_instance.native_downscaling and class_instance.smooth_image:
        # HWC -> NCHW for the blur, then back to HWC. A transpose (not a reshape) is needed to
        # restore the layout; reshaping (C, H, W) into (H, W, C) scrambles pixels when C > 1.
        blurred = _gaussian_blur_2d_nchw(jnp.transpose(image, (2, 0, 1))[jnp.newaxis])
        image = jnp.transpose(blurred[0], (1, 2, 0))

    return image.astype(jnp.uint8)
@struct.dataclass
class PixelState:
    """State carried by PixelObsWrapper between steps."""

    # State of the underlying AtariWrapper.
    atari_state: AtariState
    # Stack of preprocessed uint8 frames, oldest first, newest last.
    image_stack: jax.Array
class PixelObsWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that returns frame-stacked pixel observations.

    Each step repeats the action ``frame_skip`` times on the wrapped AtariWrapper, optionally
    max-pools the last two rendered frames, preprocesses the frame (resize / grayscale /
    smoothing) and pushes it onto a stack of ``frame_stack_size`` frames.
    Apply this wrapper after the AtariWrapper!
    """

    # TODO: remove do_pixel_resize and resize whenever a different shape / grayscale is given?
    def __init__(self, env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84),
                 grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False,
                 frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True,
                 clip_reward: bool = True):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "PixelObsWrapper has to be applied after AtariWrapper"
        self.frame_stack_size = frame_stack_size
        self.frame_skip = frame_skip
        self.max_pooling = max_pooling
        self.clip_reward = clip_reward
        self.smooth_image = smooth_image
        self.native_downscaling = False

        # Access the Base Environment
        base_env = self._env._env if isinstance(self._env, AtariWrapper) else self._env
        if do_pixel_resize and use_native_downscaling:
            # call helper from modifications to make sure that applied mods remain applied after
            # native downscaling (lazy import to avoid circular dependency)
            from jaxatari.modification import apply_native_downscaling
            self.do_pixel_resize, self.grayscale = apply_native_downscaling(
                base_env, pixel_resize_shape, grayscale
            )
            self.pixel_resize_shape = pixel_resize_shape
            self.native_downscaling = True
        else:
            self.do_pixel_resize = do_pixel_resize
            self.pixel_resize_shape = pixel_resize_shape
            self.grayscale = grayscale

        # Dynamically calculate the final observation space shape.
        # If we hot-swapped, image_space() already returns the correct small size.
        final_shape = tuple(self._env.image_space().shape)
        # Wrapper-side resizing (legacy) overrides height and width.
        if self.do_pixel_resize:
            height, width = self.pixel_resize_shape
            final_shape = (height, width, final_shape[2])
        # preprocess_image grayscales independently of resizing, so the channel axis must be
        # collapsed whenever grayscale is enabled (not only when resizing).
        if self.grayscale:
            final_shape = (final_shape[0], final_shape[1], 1)

        # Create the space for a single preprocessed frame, then stack it.
        image_space = spaces.Box(low=0, high=255, shape=final_shape, dtype=jnp.uint8)
        self._observation_space = spaces.stack_space(image_space, self.frame_stack_size)

    def observation_space(self) -> spaces.Space:
        """Returns the stacked image space."""
        return self._observation_space

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, PixelState]:
        """Reset the wrapped env and fill the stack with copies of the first rendered frame."""
        _, atari_state = self._env.reset(key)
        image = self._env.render(atari_state.env_state)
        processed_image = preprocess_image(self, image)
        image_stack = jnp.stack([processed_image] * self.frame_stack_size)
        return image_stack, PixelState(atari_state, image_stack)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelState,
        action: int,
    ) -> Tuple[chex.Array, PixelState, float, bool, bool, Dict[Any, Any]]:
        """Repeat `action` for frame_skip frames, push the newest frame, auto-reset when done."""

        def body_fn(carry, _):
            atari_state, action = carry
            _, new_atari_state, reward, terminated, truncated, info = self._env.step(atari_state, action)
            return (new_atari_state, action), (new_atari_state.env_state, reward, terminated, truncated, info)

        (atari_state, _), (env_states, rewards, terminations, truncations, infos) = jax.lax.scan(
            body_fn,
            (state.atari_state, action),
            None,
            length=self.frame_skip,
        )

        last_env_state = jax.tree.map(lambda x: x[-1], env_states)
        if self.max_pooling and self.frame_skip > 1:
            # Pixel-wise max over the last two frames of the skip window.
            image = self._env.render(last_env_state)
            prev_env_state = jax.tree.map(lambda x: x[-2], env_states)
            prev_image = self._env.render(prev_env_state)
            latest_image = jnp.maximum(image, prev_image)
        else:
            latest_image = self._env.render(last_env_state)

        processed_image = preprocess_image(self, latest_image)
        image_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)

        reward = jnp.sum(rewards)
        if self.clip_reward:
            reward = jnp.sign(reward)

        terminated = terminations.any()
        truncated = truncations.any()

        # Autoreset (gym's SAME_STEP mode) -> reset whole stack.
        # Use actual env_done for the reset condition, not affected by episodic life.
        image_stack, pixel_state = jax.lax.cond(
            jnp.logical_or(infos["env_done"].any(), truncated),
            lambda: self.reset(atari_state.key),
            lambda: (image_stack, PixelState(atari_state, image_stack)),
        )

        def reduce_info(k, v):
            # Rewards are summed over the skip window, done is OR-ed, the rest keeps the last frame.
            if k in ["env_reward", "all_rewards"]:
                return jnp.sum(v, axis=0)
            if k == "env_done":
                return jnp.any(v)
            return v[-1]

        info_dict = {k: reduce_info(k, v) for k, v in infos.items()}
        return image_stack, pixel_state, reward, terminated, truncated, info_dict
@struct.dataclass
class PixelAndObjectCentricState:
    """State carried by the combined pixel + object-centric wrappers between steps."""

    # State of the underlying AtariWrapper.
    atari_state: AtariState
    # Stack of preprocessed uint8 frames, oldest first, newest last.
    image_stack: jax.Array
    # Object-centric stack: a flat float32 array (PixelAndObjectCentricWrapper) or a pytree
    # of stacked leaves (PixelAndObjectObsWrapper).
    obs_stack: Any
class PixelAndObjectCentricWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that returns frame-stacked pixel observations together with
    frame-stacked, flattened object-centric observations.

    Observations are ``(image_stack, obs_stack)``; ``obs_stack`` has shape
    ``(frame_stack_size, num_features)`` in float32.
    Apply this wrapper after the AtariWrapper!
    """

    def __init__(self, env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84),
                 grayscale: bool = False, use_native_downscaling: bool = False, smooth_image: bool = False,
                 frame_stack_size: int = 4, frame_skip: int = 4, max_pooling: bool = True,
                 clip_reward: bool = True):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "PixelAndObjectCentricWrapper must be applied after AtariWrapper"
        self.frame_stack_size = frame_stack_size
        self.frame_skip = frame_skip
        self.max_pooling = max_pooling
        self.clip_reward = clip_reward
        self.smooth_image = smooth_image
        self.native_downscaling = False

        # Access the Base Environment
        base_env = self._env._env if isinstance(self._env, AtariWrapper) else self._env
        if do_pixel_resize and use_native_downscaling:
            # call helper from modifications to make sure that applied mods remain applied after
            # native downscaling (lazy import to avoid circular dependency)
            from jaxatari.modification import apply_native_downscaling
            self.do_pixel_resize, self.grayscale = apply_native_downscaling(
                base_env, pixel_resize_shape, grayscale
            )
            self.pixel_resize_shape = pixel_resize_shape
            self.native_downscaling = True
        else:
            self.do_pixel_resize = do_pixel_resize
            self.pixel_resize_shape = pixel_resize_shape
            self.grayscale = grayscale

        # Part 1: Define the stacked image space.
        # If we hot-swapped, image_space() already returns the correct small size.
        final_shape = tuple(self._env.image_space().shape)
        # Wrapper-side resizing (legacy) overrides height and width.
        if self.do_pixel_resize:
            height, width = self.pixel_resize_shape
            final_shape = (height, width, final_shape[2])
        # preprocess_image grayscales independently of resizing, so the channel axis must be
        # collapsed whenever grayscale is enabled (not only when resizing).
        if self.grayscale:
            final_shape = (final_shape[0], final_shape[1], 1)
        image_space = spaces.Box(low=0, high=255, shape=final_shape, dtype=jnp.uint8)
        stacked_image_space = spaces.stack_space(image_space, self.frame_stack_size)

        # Part 2: Define the FLATTENED object space with exact bounds.
        single_frame_space = self._env.observation_space()
        lows, highs = [], []
        for leaf_space in jax.tree.leaves(single_frame_space):
            if isinstance(leaf_space, spaces.Box):
                lows.append(np.broadcast_to(leaf_space.low, leaf_space.shape).flatten())
                highs.append(np.broadcast_to(leaf_space.high, leaf_space.shape).flatten())
            elif isinstance(leaf_space, spaces.Discrete):
                # A Discrete(n) leaf contributes a single feature in [0, n - 1].
                lows.append(np.array([0], dtype=leaf_space.dtype))
                highs.append(np.array([leaf_space.n - 1], dtype=leaf_space.dtype))
            else:
                raise TypeError(f"Unsupported space type for flattening: {type(leaf_space)}")
        if not lows:
            raise ValueError("The observation space appears to be empty or contain unsupported types.")
        single_frame_lows = np.concatenate(lows)
        single_frame_highs = np.concatenate(highs)
        stacked_object_space_flat = spaces.Box(
            low=single_frame_lows,
            high=single_frame_highs,
            shape=(self.frame_stack_size, int(single_frame_lows.shape[0])),
            dtype=jnp.float32,
        )

        # Part 3: Combine them into the final Tuple space.
        self._observation_space = spaces.Tuple((
            stacked_image_space,
            stacked_object_space_flat,
        ))

    def observation_space(self) -> spaces.Tuple:
        """Returns a Tuple space containing stacked image and object spaces."""
        return self._observation_space

    @functools.partial(jax.jit, static_argnums=(0,))
    def _flatten_single_obs(self, obs):
        """Flatten a single object-centric observation using ravel_pytree."""
        return flatten_util.ravel_pytree(obs)[0].astype(jnp.float32)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[Tuple[chex.Array, chex.Array], PixelAndObjectCentricState]:
        """Reset the wrapped env and fill both stacks with copies of the first frame."""
        obs, atari_state = self._env.reset(key)
        flat_obs = self._flatten_single_obs(obs)
        obs_stack = jnp.stack([flat_obs] * self.frame_stack_size)

        image = self._env.render(atari_state.env_state)
        processed_image = preprocess_image(self, image)
        image_stack = jnp.stack([processed_image] * self.frame_stack_size)

        new_state = PixelAndObjectCentricState(atari_state, image_stack, obs_stack)
        return (image_stack, obs_stack), new_state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelAndObjectCentricState,
        action: int,
    ) -> Tuple[Tuple[chex.Array, chex.Array], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]:
        """Repeat `action` for frame_skip frames, push the newest frames, auto-reset when done."""

        def body_fn(carry, _):
            atari_state, action = carry
            obs, new_atari_state, reward, terminated, truncated, info = self._env.step(atari_state, action)
            return (new_atari_state, action), (obs, new_atari_state.env_state, reward, terminated, truncated, info)

        (atari_state, _), (obs, env_states, rewards, terminations, truncations, infos) = jax.lax.scan(
            body_fn,
            (state.atari_state, action),
            None,
            length=self.frame_skip,
        )

        latest_obs = jax.tree.map(lambda x: x[-1], obs)
        flat_latest_obs = self._flatten_single_obs(latest_obs)
        obs_stack = jnp.concatenate([state.obs_stack[1:], jnp.expand_dims(flat_latest_obs, axis=0)], axis=0)

        last_env_state = jax.tree.map(lambda x: x[-1], env_states)
        if self.max_pooling and self.frame_skip > 1:
            # Pixel-wise max over the last two frames of the skip window.
            image = self._env.render(last_env_state)
            prev_env_state = jax.tree.map(lambda x: x[-2], env_states)
            prev_image = self._env.render(prev_env_state)
            latest_image = jnp.maximum(image, prev_image)
        else:
            latest_image = self._env.render(last_env_state)

        processed_image = preprocess_image(self, latest_image)
        image_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)

        reward = jnp.sum(rewards)
        if self.clip_reward:
            reward = jnp.sign(reward)

        terminated = terminations.any()
        truncated = truncations.any()

        # Autoreset (gym's SAME_STEP mode) -> reset whole stack.
        # Use actual env_done for the reset condition, not affected by episodic life.
        (image_stack, obs_stack), pixel_oc_state = jax.lax.cond(
            jnp.logical_or(infos["env_done"].any(), truncated),
            lambda: self.reset(atari_state.key),
            lambda: ((image_stack, obs_stack), PixelAndObjectCentricState(atari_state, image_stack, obs_stack)),
        )

        def reduce_info(k, v):
            # Rewards are summed over the skip window, done is OR-ed, the rest keeps the last frame.
            if k in ["env_reward", "all_rewards"]:
                return jnp.sum(v, axis=0)
            if k == "env_done":
                return jnp.any(v)
            return v[-1]

        info_dict = {k: reduce_info(k, v) for k, v in infos.items()}
        return (image_stack, obs_stack), pixel_oc_state, reward, terminated, truncated, info_dict
class PixelAndObjectObsWrapper(PixelAndObjectCentricWrapper):
    """
    Exactly the same as PixelAndObjectCentricWrapper, but returns structured OC-obs instead of a
    flattened array.

    The object-centric part is the env's observation pytree with every leaf stacked along a new
    leading axis of size ``frame_stack_size``. Reward clipping honours ``clip_reward`` just like
    the parent class.

    NOTE(review): ``observation_space()`` is inherited from PixelAndObjectCentricWrapper and still
    describes the *flattened* object space, not the structured pytree returned here — confirm
    whether callers rely on it.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[Tuple[chex.Array, Any], PixelAndObjectCentricState]:
        """Reset the wrapped env and fill both stacks with copies of the first frame."""
        obs, atari_state = self._env.reset(key)

        image = self._env.render(atari_state.env_state)
        processed_image = preprocess_image(self, image)
        image_stack = jnp.stack([processed_image] * self.frame_stack_size)

        # Stack each observation leaf independently, preserving the pytree structure.
        obs_stack = jax.tree.map(
            lambda leaf: jnp.stack([leaf] * self.frame_stack_size),
            obs,
        )

        new_state = PixelAndObjectCentricState(atari_state, image_stack, obs_stack)
        return (image_stack, obs_stack), new_state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelAndObjectCentricState,
        action: int,
    ) -> Tuple[Tuple[chex.Array, Any], PixelAndObjectCentricState, float, bool, bool, Dict[Any, Any]]:
        """Repeat `action` for frame_skip frames, push the newest frames, auto-reset when done."""

        def body_fn(carry, _):
            atari_state, action = carry
            obs, new_atari_state, reward, terminated, truncated, info = self._env.step(atari_state, action)
            return (new_atari_state, action), (obs, new_atari_state.env_state, reward, terminated, truncated, info)

        (atari_state, _), (obs, env_states, rewards, terminations, truncations, infos) = jax.lax.scan(
            body_fn,
            (state.atari_state, action),
            None,
            length=self.frame_skip,
        )

        # Shift every stacked leaf by one and append the newest frame's leaf.
        latest_obs = jax.tree.map(lambda x: x[-1], obs)
        obs_stack = jax.tree.map(
            lambda stack_leaf, obs_leaf: jnp.concatenate([stack_leaf[1:], jnp.expand_dims(obs_leaf, axis=0)], axis=0),
            state.obs_stack,
            latest_obs,
        )

        last_env_state = jax.tree.map(lambda x: x[-1], env_states)
        if self.max_pooling and self.frame_skip > 1:
            # Pixel-wise max over the last two frames of the skip window.
            image = self._env.render(last_env_state)
            prev_env_state = jax.tree.map(lambda x: x[-2], env_states)
            prev_image = self._env.render(prev_env_state)
            latest_image = jnp.maximum(image, prev_image)
        else:
            latest_image = self._env.render(last_env_state)

        processed_image = preprocess_image(self, latest_image)
        image_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)

        reward = jnp.sum(rewards)
        # Previously missing: honour clip_reward exactly like the parent class does.
        if self.clip_reward:
            reward = jnp.sign(reward)

        terminated = terminations.any()
        truncated = truncations.any()

        # Autoreset (gym's SAME_STEP mode) -> reset whole stack.
        # Use actual env_done for the reset condition, not affected by episodic life.
        (image_stack, obs_stack), pixel_oc_state = jax.lax.cond(
            jnp.logical_or(infos["env_done"].any(), truncated),
            lambda: self.reset(atari_state.key),
            lambda: ((image_stack, obs_stack), PixelAndObjectCentricState(atari_state, image_stack, obs_stack)),
        )

        def reduce_info(k, v):
            # Rewards are summed over the skip window, done is OR-ed, the rest keeps the last frame.
            if k in ["env_reward", "all_rewards"]:
                return jnp.sum(v, axis=0)
            if k == "env_done":
                return jnp.any(v)
            return v[-1]

        info_dict = {k: reduce_info(k, v) for k, v in infos.items()}
        return (image_stack, obs_stack), pixel_oc_state, reward, terminated, truncated, info_dict
class FlattenObservationWrapper(JaxatariWrapper):
    """
    A wrapper that flattens each leaf array in an observation pytree.

    Compatible with all the other wrappers: it flattens the observations while preserving the
    overarching structure (i.e. if the observation is a tuple of multiple observations, the
    flattened observation will be a tuple of flattened observations). Array leaves are cast
    to float32, and the reported observation space uses float32 accordingly.
    """

    def __init__(self, env):
        super().__init__(env)
        # build the new (flattened) observation space
        original_space = self._env.observation_space()

        def flatten_space(space: spaces.Box) -> spaces.Box:
            # Broadcast the original bounds to the full shape before flattening, so that every
            # element keeps its own bound.
            flat_low = np.broadcast_to(space.low, space.shape).flatten()
            flat_high = np.broadcast_to(space.high, space.shape).flatten()
            # _process_obs casts every array leaf to float32, so advertise float32 here to
            # match the observations actually returned (previously the original dtype leaked).
            return spaces.Box(
                low=jnp.array(flat_low),
                high=jnp.array(flat_high),
                dtype=jnp.float32,
            )

        self._observation_space = jax.tree.map(
            flatten_space,
            original_space,
            is_leaf=lambda x: isinstance(x, spaces.Box),
        )

    def observation_space(self) -> spaces.Space:
        """Returns a space where each leaf array is flattened."""
        return self._observation_space

    def _process_obs(self, obs_tree: chex.ArrayTree) -> chex.ArrayTree:
        """Applies .flatten() to each leaf array in the pytree and casts array leaves to float32."""
        def flatten_and_cast(leaf):
            flattened = leaf.flatten()
            # Cast to float32 to match the space dtype.
            return flattened.astype(jnp.float32) if isinstance(leaf, jnp.ndarray) else flattened

        return jax.tree.map(flatten_and_cast, obs_tree)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.ArrayTree, Any]:
        """Reset the wrapped env and flatten the initial observation."""
        obs, state = self._env.reset(key)
        processed_obs = self._process_obs(obs)
        return processed_obs, state  # State can be passed through directly

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: Any,
        action: int,
    ) -> Tuple[chex.ArrayTree, Any, float, bool, bool, Dict[str, Any]]:
        """Step the wrapped env and flatten the resulting observation."""
        obs, next_state, reward, terminated, truncated, info = self._env.step(state, action)
        processed_obs = self._process_obs(obs)
        return processed_obs, next_state, reward, terminated, truncated, info
class NormalizeObservationWrapper(JaxatariWrapper):
    """
    A wrapper that normalizes each leaf in an observation Pytree.

    Each ``spaces.Box`` leaf of the wrapped env's observation space is mapped
    linearly from ``[low, high]`` to ``[0, 1]`` (or ``[-1, 1]`` when
    ``to_neg_one=True``), clipped to that interval and cast to ``dtype``.
    The arithmetic itself runs in at least float32 and only the final result
    is cast, so bounds a narrow ``dtype`` cannot hold (float16 overflows above
    65504, e.g. int32-sized counters) still normalize correctly instead of
    collapsing to 0.

    Leaves with infinite bounds are not special-cased and will produce
    non-finite values.

    This wrapper is compatible with any observation structure (Pytrees).
    """

    def __init__(self, env, to_neg_one: bool = False, dtype=jnp.float16):
        super().__init__(env)
        self._to_neg_one = to_neg_one
        self._dtype = dtype
        original_space = self._env.observation_space()

        def _is_box(node) -> bool:
            return isinstance(node, spaces.Box)

        # Pytrees with the same structure as the observations, holding the
        # low/high bounds of every Box leaf in that Box's own dtype.
        self._low = jax.tree.map(
            lambda s: jnp.array(s.low, dtype=s.dtype), original_space, is_leaf=_is_box
        )
        self._high = jax.tree.map(
            lambda s: jnp.array(s.high, dtype=s.dtype), original_space, is_leaf=_is_box
        )

        # The new observation space keeps the same structure, but every Box
        # leaf becomes a [0, 1] (or [-1, 1]) Box of the target dtype.
        def _normalize_space(space: spaces.Box) -> spaces.Box:
            low_val = -1.0 if self._to_neg_one else 0.0
            return spaces.Box(
                low=low_val, high=1.0, shape=space.shape, dtype=self._dtype
            )

        self._observation_space = jax.tree.map(
            _normalize_space, original_space, is_leaf=_is_box
        )

    def observation_space(self) -> spaces.Space:
        """Return the normalized observation space.

        It has the same structure as the wrapped env's space; every Box leaf
        has bounds [0, 1] (or [-1, 1] when ``to_neg_one=True``) and the
        wrapper's target dtype.
        """
        return self._observation_space

    def _normalize_leaf(self, obs_leaf, low_leaf, high_leaf):
        """Normalize a single observation leaf against its Box bounds.

        Args:
            obs_leaf: observation array for this leaf.
            low_leaf: lower bound(s), broadcastable to ``obs_leaf``.
            high_leaf: upper bound(s), broadcastable to ``obs_leaf``.

        Returns:
            ``obs_leaf`` mapped to [0, 1] (or [-1, 1]), clipped, and cast to
            the wrapper's target dtype.
        """
        # Compute in at least float32: doing the math in e.g. float16 turns
        # bounds above 65504 into inf (scale -> 0) and rounds large integers.
        compute_dtype = jnp.promote_types(self._dtype, jnp.float32)
        obs = obs_leaf.astype(compute_dtype)
        low = low_leaf.astype(compute_dtype)
        high = high_leaf.astype(compute_dtype)

        value_range = high - low
        # Zero-width bounds (low == high) would divide by ~0; use scale 1 there.
        scale = 1.0 / jnp.where(value_range > 1e-8, value_range, 1.0)
        normalized = (obs - low) * scale

        # `to_neg_one` is a static Python value (self is a static jit arg), so
        # plain Python control flow is enough; lax.cond would trace both paths.
        if self._to_neg_one:
            normalized = 2.0 * normalized - 1.0
            clip_low = -1.0
        else:
            clip_low = 0.0
        return jnp.clip(normalized, clip_low, 1.0).astype(self._dtype)

    def _normalize_obs(self, obs: chex.ArrayTree) -> chex.ArrayTree:
        """
        Applies normalization to each leaf array in the observation pytree,
        robustly handling structural mismatches between observation and space
        Pytrees (only the flattened leaf order has to agree).
        """
        # Pair leaves positionally; the observation and the bound pytrees may
        # use different container types but are expected to flatten to the
        # same number and order of leaves.
        # NOTE(review): zip truncates silently if the leaf counts differ.
        obs_leaves = jax.tree.leaves(obs)
        low_leaves = jax.tree.leaves(self._low)
        high_leaves = jax.tree.leaves(self._high)

        normalized_leaves = [
            self._normalize_leaf(o, lo, hi)
            for o, lo, hi in zip(obs_leaves, low_leaves, high_leaves)
        ]
        # Rebuild with the structure of the incoming observation.
        return jax.tree.unflatten(jax.tree.structure(obs), normalized_leaves)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.ArrayTree, Any]:
        """Reset the wrapped env and return its normalized observation and state."""
        obs, state = self._env.reset(key)
        normalized_obs = self._normalize_obs(obs)
        return normalized_obs, state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: Any,
        action: int,
    ) -> Tuple[chex.ArrayTree, Any, float, bool, bool, Dict[str, Any]]:
        """Step the wrapped env; only the observation is normalized, the rest passes through."""
        obs, next_state, reward, terminated, truncated, info = self._env.step(state, action)
        normalized_obs = self._normalize_obs(obs)
        return normalized_obs, next_state, reward, terminated, truncated, info
@struct.dataclass
class LogState:
    """State carried by LogWrapper: the wrapped env state plus episode statistics.

    The statistic fields are annotated float/int, but LogWrapper.step stores
    the results of ``jnp.where`` in them, i.e. 0-d JAX arrays.
    """

    atari_state: Any  # Can be any of the states from wrappers above
    # Running return / length of the current episode; reset to 0 when an episode ends.
    episode_returns: float
    episode_lengths: int
    # Return / length of the most recently completed episode; updated only when an episode ends.
    returned_episode_returns: float
    returned_episode_lengths: int
class LogWrapper(JaxatariWrapper):
    """Log episode returns and lengths.

    An episode ends, for logging purposes, when ``info["env_done"]`` is true
    (falling back to the wrapped env's ``terminated`` flag when that key is
    absent) or when the step is truncated. Uses ``info["env_reward"]`` when
    present (unclipped); otherwise uses the step reward.

    Adds three keys to the returned ``info`` dict:
    ``returned_episode_returns`` / ``returned_episode_lengths`` (statistics of
    the most recently completed episode) and ``returned_episode`` (whether an
    episode ended on this step). Reward and termination signals are passed
    through unchanged.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, LogState]:
        """Reset the wrapped env and zero every episode statistic.

        The counters use the dtypes ``step`` writes (float32 returns, int32
        lengths), so the LogState pytree has the same leaf types after
        ``reset`` as after ``step``. Python scalars would instead become
        weakly-typed arrays whose dtype depends on the x64 setting.
        """
        obs, atari_state = self._env.reset(key)
        state = LogState(
            atari_state=atari_state,
            episode_returns=jnp.float32(0.0),
            episode_lengths=jnp.int32(0),
            returned_episode_returns=jnp.float32(0.0),
            returned_episode_lengths=jnp.int32(0),
        )
        return obs, state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: LogState,
        action: int,
    ) -> Tuple[chex.Array, LogState, float, bool, bool, Dict[Any, Any]]:
        """Step the wrapped env and update the episode statistics.

        Args:
            state: current LogState.
            action: action forwarded to the wrapped env.

        Returns:
            ``(obs, new_state, reward, terminated, truncated, info)``. Reward
            and termination flags come straight from the wrapped env. ``info``
            is a copy of the wrapped env's info extended with the
            ``returned_episode*`` keys.
        """
        obs, atari_state, reward, terminated, truncated, info = self._env.step(
            state.atari_state, action
        )
        # Work on a fresh dict so the wrapped env's info object is never
        # mutated; NamedTuple infos (no .get / item assignment) are converted.
        info = info._asdict() if hasattr(info, "_asdict") else dict(info)

        # use env_reward (unclipped/unchanged) for logging when available
        new_episode_return = state.episode_returns + info.get("env_reward", reward)
        new_episode_length = state.episode_lengths + 1

        # use env_done for logging when available (e.g. to ignore episodic_life);
        # truncated episodes are considered done for logging purposes as well
        done = jnp.logical_or(info.get("env_done", jnp.bool_(terminated)), truncated)

        state = LogState(
            atari_state=atari_state,
            episode_returns=jnp.where(done, jnp.float32(0), jnp.float32(new_episode_return)),
            episode_lengths=jnp.where(done, jnp.int32(0), jnp.int32(new_episode_length)),
            returned_episode_returns=jnp.where(
                done, jnp.float32(new_episode_return), jnp.float32(state.returned_episode_returns)
            ),
            returned_episode_lengths=jnp.where(
                done, jnp.int32(new_episode_length), jnp.int32(state.returned_episode_lengths)
            ),
        )
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = done
        # Still return the actual/wrapped termination signal (affected by episodic life)
        return obs, state, reward, terminated, truncated, info
@struct.dataclass
class MultiRewardLogState:
    """State carried by MultiRewardLogWrapper: wrapped env state plus per-episode statistics.

    Statistics are tracked both for the env reward (``*_env`` fields) and for
    every reward function (array fields, one entry per element of
    ``info["all_rewards"]``).
    """

    atari_state: Any  # Can be any of the states from wrappers above
    # Running env-reward return of the current episode; reset to 0 when an episode ends.
    episode_returns_env: float
    # Running return per reward function for the current episode (shape of info["all_rewards"]).
    episode_returns: chex.Array
    # Number of steps taken in the current episode.
    episode_lengths: int
    # Env-reward return of the most recently completed episode.
    returned_episode_returns_env: float
    # Per-reward-function returns of the most recently completed episode.
    returned_episode_returns: chex.Array
    # Length of the most recently completed episode.
    returned_episode_lengths: int
class MultiRewardLogWrapper(JaxatariWrapper):
    """Log episode returns and lengths for multiple rewards.

    An episode ends, for logging purposes, when ``info["env_done"]`` is true
    (falling back to the wrapped env's ``terminated`` flag) or when the step
    is truncated, the same rule LogWrapper uses. Apply MultiRewardWrapper to
    the core env when using this wrapper; its per-step ``info["all_rewards"]``
    vector is accumulated element-wise.

    Final logs: 'returned_episode_returns_0', ... for each reward function;
    env reward in 'returned_episode_env_returns'.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self,
        key: chex.PRNGKey,
    ) -> Tuple[chex.Array, MultiRewardLogState]:
        """Reset the wrapped env and zero every episode statistic.

        The number of reward functions is taken from the shape/dtype of
        ``info["all_rewards"]``. ``jax.eval_shape`` traces one step abstractly
        to obtain it, so no env step is executed or compiled here. Without an
        ``all_rewards`` entry, a single zero accumulator is used.
        """
        obs, atari_state = self._env.reset(key)

        _, _, _, _, _, info_spec = jax.eval_shape(self._env.step, atari_state, 0)
        if hasattr(info_spec, "_asdict"):
            info_spec = info_spec._asdict()
        rewards_spec = info_spec.get("all_rewards")
        if rewards_spec is None:
            episode_returns_init = jnp.zeros(1)
        else:
            episode_returns_init = jnp.zeros(rewards_spec.shape, rewards_spec.dtype)

        # Explicit dtypes match what `step` writes, keeping the pytree types stable.
        state = MultiRewardLogState(
            atari_state=atari_state,
            episode_returns_env=jnp.float32(0.0),
            episode_returns=episode_returns_init,
            episode_lengths=jnp.int32(0),
            returned_episode_returns_env=jnp.float32(0.0),
            returned_episode_returns=episode_returns_init,
            returned_episode_lengths=jnp.int32(0),
        )
        return obs, state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: MultiRewardLogState,
        action: int,
    ) -> Tuple[chex.Array, MultiRewardLogState, float, bool, bool, Dict[Any, Any]]:
        """Step the wrapped env and update env-reward and per-reward-function statistics.

        Returns:
            ``(obs, new_state, reward, terminated, truncated, info)``. Reward
            and termination flags come straight from the wrapped env. ``info``
            is a copy of the wrapped env's info extended with
            ``returned_episode_env_returns``, ``returned_episode_returns_{i}``,
            ``returned_episode_lengths`` and ``returned_episode``.
        """
        obs, atari_state, reward, terminated, truncated, info = self._env.step(
            state.atari_state, action
        )
        # Work on a fresh dict so the wrapped env's info object is never
        # mutated; NamedTuple infos (no .get / item assignment) are converted.
        info = info._asdict() if hasattr(info, "_asdict") else dict(info)

        # use env_reward (unclipped/unchanged) for logging when available
        new_episode_return_env = state.episode_returns_env + info.get("env_reward", reward)
        all_rewards_step = info.get("all_rewards", jnp.zeros_like(state.episode_returns))
        new_episode_return = state.episode_returns + all_rewards_step
        new_episode_length = state.episode_lengths + 1

        # env_done when available (e.g. to ignore episodic_life); truncated
        # episodes also end the logged episode, otherwise returns would keep
        # accumulating across a time-limit reset.
        done = jnp.logical_or(info.get("env_done", jnp.bool_(terminated)), truncated)

        state = MultiRewardLogState(
            atari_state=atari_state,
            episode_returns_env=jnp.where(done, jnp.float32(0), jnp.float32(new_episode_return_env)),
            episode_returns=jnp.where(done, jnp.zeros_like(state.episode_returns), new_episode_return),
            episode_lengths=jnp.where(done, jnp.int32(0), jnp.int32(new_episode_length)),
            returned_episode_returns_env=jnp.where(
                done, jnp.float32(new_episode_return_env), jnp.float32(state.returned_episode_returns_env)
            ),
            returned_episode_returns=jnp.where(
                done, new_episode_return, state.returned_episode_returns
            ),
            returned_episode_lengths=jnp.where(
                done, jnp.int32(new_episode_length), jnp.int32(state.returned_episode_lengths)
            ),
        )
        info["returned_episode_env_returns"] = state.returned_episode_returns_env
        # One scalar entry per reward function (the array length is static under jit).
        for i, episode_return in enumerate(state.returned_episode_returns):
            info[f"returned_episode_returns_{i}"] = episode_return
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = done
        return obs, state, reward, terminated, truncated, info