"""Jaxatari Wrappers"""
import functools
from typing import Any, Dict, Tuple, Union, Optional, Callable
from dataclasses import is_dataclass, asdict
import chex
from flax import struct
import jax
import jax.image as jim
import jax.numpy as jnp
from jaxatari.environment import EnvState, JAXAtariAction as Action
import jaxatari.spaces as spaces
import numpy as np
[docs]
class JaxatariWrapper(object):
    """Base class for JAXAtari wrappers.

    Stores the wrapped environment on ``self._env`` and forwards every
    attribute lookup that fails on the wrapper itself to that environment.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        """Provide proxy access to regular attributes of the wrapped object.

        Args:
            name: Attribute that normal lookup did not find on the wrapper.

        Returns:
            The attribute of the same name on the wrapped environment.

        Raises:
            AttributeError: If the wrapper has no ``_env`` yet, or the wrapped
                environment does not have the attribute either.
        """
        # ``__getattr__`` only runs after normal lookup failed. If the missing
        # attribute is ``_env`` itself (instance built without ``__init__``,
        # e.g. by ``copy``/``pickle``), reading ``self._env`` below would
        # re-enter this method forever; fail with AttributeError instead so
        # ``hasattr``/``getattr(..., default)`` keep working.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)
[docs]
class MultiRewardWrapper(JaxatariWrapper):
    """
    Evaluates several reward functions at every step.

    Each callable is invoked as ``reward_func(previous_state, state)`` and the
    results are stored as one array under ``info["all_rewards"]``. The reward
    returned by the wrapped environment is passed through unchanged.
    Apply this wrapper directly after the base environment, before any other wrappers.
    """

    def __init__(self, env, reward_funcs: list[Callable]):
        super().__init__(env)
        assert isinstance(reward_funcs, list) and len(reward_funcs) > 0, "reward_funcs must be a non-empty list of callables"
        self._reward_funcs = reward_funcs

    @functools.partial(jax.jit, static_argnums=(0,))
    def _get_all_rewards(self, previous_state: EnvState, state: EnvState) -> chex.Array:
        """Return one entry per configured reward function for this transition."""
        funcs = self._reward_funcs
        if funcs is None:
            return jnp.zeros(1)
        per_function = []
        for fn in funcs:
            per_function.append(fn(previous_state, state))
        return jnp.array(per_function)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, state: EnvState, action: int) -> Tuple[chex.Array, EnvState, float, bool, Dict]:
        """Step the wrapped env and attach ``all_rewards`` to the returned info."""
        obs, new_state, reward, done, info = self._env.step(state, action)
        # The info must be a mutable mapping before a key can be added:
        # NamedTuples expose ``_asdict``; dataclasses go through ``asdict``.
        # Anything else is assumed to already behave like a dict.
        if hasattr(info, '_asdict'):
            info = info._asdict()
        else:
            if is_dataclass(info):
                info = asdict(info)
        info["all_rewards"] = self._get_all_rewards(state, new_state)
        return obs, new_state, reward, done, info
[docs]
@struct.dataclass
class AtariState:
    """State threaded through AtariWrapper.reset / AtariWrapper.step."""
    # State of the wrapped environment.
    env_state: EnvState
    # PRNG key owned by the wrapper; AtariWrapper.step splits it on every call.
    key: chex.PRNGKey
    # Step counter: reset adds the no-op / first-fire steps, every step adds 1.
    step: int
    # Action applied on the previous step (re-used when a sticky action triggers).
    prev_action: int
    # Observation pytree with a leading frame-stack axis on every leaf.
    obs_stack: chex.Array
[docs]
class AtariWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that applies the usual Atari preprocessing
    (sticky actions, frame skipping, no-op/fire resets, episodic life, reward
    clipping, max pooling) and returns the wrapped env's unflattened
    observations stacked over the last ``frame_stack_size`` frames.

    Args:
        env: The environment to wrap.
        sticky_actions: Whether to use sticky actions. ``True`` repeats the
            previous action with probability 0.25 and ``False`` disables
            repeats. A ``float`` is used directly as the repeat probability
            and must lie in [0, 1].
        frame_stack_size: The number of frames to stack.
        frame_skip: The number of frames to skip (env steps per wrapper step).
        max_episode_length: ``done`` is raised once the wrapper's step counter
            has reached this value.
        episodic_life: Treat a lost life (``lives`` decreasing or
            ``lives_lost`` increasing on the env state) as end of episode.
        first_fire: Send one FIRE action during reset. Switched off
            automatically when the env's ``ACTION_SET`` contains no FIRE.
        noop_reset: Upper bound for the random number of NOOP steps executed
            during reset; ``0`` disables no-op resets.
        clip_reward: Return the sign of the summed reward instead of the sum.
        max_pooling: Use the element-wise maximum of the last two skipped
            frames instead of only the last frame.
        full_action_space: Replace the env's ``ACTION_SET`` with all 18 actions.

    Raises:
        ValueError: If ``sticky_actions`` is a float outside [0, 1].
    """

    def __init__(
        self,
        env,
        sticky_actions: Union[bool, float] = True,
        frame_stack_size: int = 4,
        frame_skip: int = 4,
        max_episode_length: int = 10_000,
        episodic_life: bool = True,
        first_fire: bool = True,
        noop_reset: int = 0,
        clip_reward: bool = False,
        max_pooling: bool = False,
        full_action_space: bool = False,
    ):
        super().__init__(env)
        # Validate eagerly so a bad probability fails here, not inside a jit trace.
        self._sticky_action_probability(sticky_actions)
        self.sticky_actions = sticky_actions
        self.frame_stack_size = frame_stack_size
        self.frame_skip = frame_skip
        self.max_episode_length = max_episode_length
        self.episodic_life = episodic_life
        self.first_fire = first_fire
        self.noop_reset = noop_reset != 0
        self.noop_max = noop_reset
        self.clip_reward = clip_reward
        self.max_pooling = max_pooling
        self.full_action_space = full_action_space

        # --- 1) HANDLE FULL ACTION SPACE LOGIC ---
        # If requested, swap the environment's (minimal) action set for the full identity set.
        # This keeps each game env "clean" while enabling a central switch for experimentation.
        if self.full_action_space and hasattr(self._env, "ACTION_SET"):
            # Overwrite the instance attribute with [0, 1, ... 17]
            self._env.ACTION_SET = jnp.arange(18, dtype=jnp.int32)

        # --- 2) RESOLVE CORRECT 'FIRE' ACTION INDEX ---
        # The wrapped env expects an *index* into ACTION_SET (agent action), not the ALE action constant.
        self.fire_action_index: int = int(Action.FIRE)  # fallback if env doesn't expose ACTION_SET
        if hasattr(self._env, "ACTION_SET"):
            # Convert to numpy for search (safe in __init__)
            action_set_np = np.array(self._env.ACTION_SET)
            fire_indices = np.where(action_set_np == int(Action.FIRE))[0]
            if len(fire_indices) > 0:
                self.fire_action_index = int(fire_indices[0])
            else:
                # Game has no FIRE action (e.g. Freeway).
                # Disable first_fire to prevent sending a random command by mistake.
                self.first_fire = False

        self._observation_space = spaces.stack_space(self._env.observation_space(), self.frame_stack_size)

    @staticmethod
    def _sticky_action_probability(sticky_actions) -> float:
        """Translate the ``sticky_actions`` setting into a repeat probability.

        Floats are probabilities; every other value keeps the historic flag
        semantics (truthy -> 0.25, falsy -> 0.0).
        """
        if isinstance(sticky_actions, (float, np.floating)):
            prob = float(sticky_actions)
            if not 0.0 <= prob <= 1.0:
                raise ValueError(f"sticky_actions probability must be in [0, 1], got {sticky_actions}")
            return prob
        return 0.25 if sticky_actions else 0.0

    def observation_space(self) -> spaces.Space:
        """Returns the stacked observation space."""
        return self._observation_space

    def image_space(self) -> spaces.Box:
        """Returns the image space."""
        return self._env.image_space()

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.ArrayTree, AtariState]:
        """Reset the env, run optional no-op / fire steps and build the frame stack.

        Args:
            key: PRNG key; split into an env key, the wrapper's running key
                and a key for drawing the number of no-ops.

        Returns:
            ``(obs_stack, state)`` where every observation leaf is the final
            reset observation repeated ``frame_stack_size`` times.
        """
        # Split keys for all potential random operations
        env_key, wrapper_key, noop_key = jax.random.split(key, 3)
        obs, env_state = self._env.reset(env_key)
        step = jnp.array(0, dtype=jnp.int32)
        prev_action = jnp.array(0, dtype=jnp.int32)
        # TODO: in which order should the noop and first_fire be done?

        # ========== NOOP RESET ==========
        def perform_noop_reset(carry):
            # This function will be executed if self.noop_reset is True
            env_state, obs, step = carry
            # Generate the random number of no-op steps to take.
            num_noops = jax.random.randint(noop_key, shape=(), minval=0, maxval=self.noop_max + 1)

            def noop_body_fn(i, loop_carry):
                current_env_state, current_obs = loop_carry
                # We always compute the next step for static graph tracing...
                next_obs, next_env_state, _, _, _ = self._env.step(current_env_state, Action.NOOP)
                # ...but only apply the update if the loop index is less than our dynamic random number.
                env_state_out = jax.lax.cond(i < num_noops, lambda: next_env_state, lambda: current_env_state)
                obs_out = jax.lax.cond(i < num_noops, lambda: next_obs, lambda: current_obs)
                return env_state_out, obs_out

            # Loop for the static maximum number of no-ops.
            final_env_state, final_obs = jax.lax.fori_loop(0, self.noop_max, noop_body_fn, (env_state, obs))
            # Update the step counter by the dynamic number of no-ops performed.
            final_step = step + num_noops
            return final_env_state, final_obs, final_step

        # Use lax.cond to conditionally apply the whole no-op block based on the static self.noop_reset flag.
        env_state, obs, step = jax.lax.cond(
            self.noop_reset,
            lambda carry: perform_noop_reset(carry),
            lambda carry: carry,
            (env_state, obs, step)
        )

        # ========== FIRST FIRE ==========
        def perform_first_fire(carry):
            env_state, obs, step, _ = carry
            fire_obs, fire_env_state, _, _, _ = self._env.step(env_state, self.fire_action_index)
            return fire_env_state, fire_obs, step + 1, self.fire_action_index

        def identity_fire(carry):
            return carry

        # Conditionally apply the fire action based on the static self.first_fire flag.
        env_state, obs, step, prev_action = jax.lax.cond(
            self.first_fire,
            perform_first_fire,
            identity_fire,
            (env_state, obs, step, prev_action)
        )

        # Create the initial frame stack from the final observation.
        obs = jax.tree.map(lambda x: jnp.stack([x] * self.frame_stack_size), obs)
        return obs, AtariState(env_state, wrapper_key, step, prev_action, obs)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, state: AtariState, action: Union[int, float]) -> Tuple[Tuple[chex.Array, chex.Array], AtariState, float, bool, Dict[Any, Any]]:
        """Apply ``action`` for ``frame_skip`` env steps and update the frame stack.

        Args:
            state: Wrapper state from ``reset`` or a previous ``step``.
            action: Agent action (an index into the env's action set).

        Returns:
            ``(obs_stack, new_state, reward, done, info)``. ``reward`` is the
            sum over the skipped frames (its sign when ``clip_reward``).
            ``info`` holds the last skipped frame's entries, except
            ``all_rewards`` which is summed over the skipped frames. When
            ``done`` is set the returned observation and state already come
            from a fresh ``reset``.
        """
        step_key, next_state_key = jax.random.split(state.key)
        # Resolved at trace time from the (static) sticky_actions setting.
        sticky_prob = self._sticky_action_probability(self.sticky_actions)
        # uniform() samples from [0, 1), so a probability of 0.0 never triggers.
        use_sticky_action = jax.random.uniform(step_key, shape=()) < sticky_prob
        new_action = jax.lax.cond(use_sticky_action, lambda: state.prev_action, lambda: action)

        # use scan to step the env for frame_skip times
        def body_fn(carry, _):
            env_state, action = carry
            obs, new_env_state, reward, done, info = self._env.step(env_state, action)
            return (new_env_state, action), (obs, reward, done, info)

        (new_env_state, new_action), (obs, rewards, dones, infos) = jax.lax.scan(
            body_fn,
            (state.env_state, new_action),
            None,
            length=self.frame_skip,
        )

        # ========== MAX POOLING LOGIC ==========
        def do_max_pool(obs_pytree):
            # Take the element-wise maximum over the last two frames.
            last_obs = jax.tree.map(lambda x: x[-1], obs_pytree)
            second_last_obs = jax.tree.map(lambda x: x[-2], obs_pytree)
            return jax.tree.map(jnp.maximum, last_obs, second_last_obs)

        def take_last_frame(obs_pytree):
            # Default behavior: just take the final frame.
            return jax.tree.map(lambda x: x[-1], obs_pytree)

        # Conditionally apply max-pooling based on the static flag.
        latest_obs = jax.lax.cond(self.max_pooling, do_max_pool, take_last_frame, obs)

        # push latest obs into the stack
        new_obs_stack = jax.tree.map(
            lambda stack, obs_leaf: jnp.concatenate([stack[1:], jnp.expand_dims(obs_leaf, axis=0)], axis=0),
            state.obs_stack,
            latest_obs,
        )

        reward = jnp.sum(rewards)
        done = jnp.logical_or(dones.any(), state.step >= self.max_episode_length)
        if self.episodic_life:
            # If the player has lost a life, we consider the episode done
            if hasattr(state.env_state, "lives"):
                done = jnp.logical_or(done, new_env_state.lives < state.env_state.lives)
            elif hasattr(state.env_state, "lives_lost"):
                done = jnp.logical_or(done, new_env_state.lives_lost > state.env_state.lives_lost)

        def reduce_info(k, v):
            # Every info entry carries a leading frame_skip axis from the scan.
            if k == "all_rewards":
                return v.sum(axis=0)
            else:
                return v[-1]

        if hasattr(infos, '_asdict'):
            # It's a namedtuple or similar, convert to dict
            info_items = infos._asdict().items()
        elif is_dataclass(infos):
            # Dataclass infos are converted the same way MultiRewardWrapper does.
            info_items = asdict(infos).items()
        else:
            # It's already a dict
            info_items = infos.items()
        info_dict = {k: reduce_info(k, v) for k, v in info_items}

        # Use jax.lax.cond to correctly handle state and key propagation on reset
        def _reset_fn(_):
            # When done, reset. The new state will contain the properly advanced next_state_key.
            return self.reset(next_state_key)

        def _step_fn(_):
            # When not done, create the next state, passing next_state_key for the *next* step.
            next_state = AtariState(new_env_state, next_state_key, state.step + 1, new_action, new_obs_stack)
            return new_obs_stack, next_state

        new_obs, new_state = jax.lax.cond(done, _reset_fn, _step_fn, operand=None)

        reward = jax.lax.cond(
            self.clip_reward,
            lambda reward: jnp.sign(reward),
            lambda reward: reward,
            reward
        )
        return new_obs, new_state, reward, done, info_dict
[docs]
class ObjectCentricWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that returns stacked object-centric observations.
    Every frame of the stack is flattened to a 1D feature vector, so the
    observation is a 2D array of shape (frame_stack_size, num_features).
    Apply this wrapper after the AtariWrapper!
    """

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "ObjectCentricWrapper must be applied after AtariWrapper"

        # Bounds are derived from the base env's space for ONE unstacked frame
        # (self._env is the AtariWrapper, self._env._env the env it wraps).
        low_parts = []
        high_parts = []
        for leaf_space in jax.tree.leaves(self._env._env.observation_space()):
            if isinstance(leaf_space, spaces.Discrete):
                # A Discrete space contributes a single value in [0, n - 1].
                low_parts.append(np.array([0], dtype=leaf_space.dtype))
                high_parts.append(np.array([leaf_space.n - 1], dtype=leaf_space.dtype))
                continue
            if isinstance(leaf_space, spaces.Box):
                # Expand (possibly scalar) bounds to the full shape, then flatten.
                low_parts.append(np.broadcast_to(leaf_space.low, leaf_space.shape).flatten())
                high_parts.append(np.broadcast_to(leaf_space.high, leaf_space.shape).flatten())
                continue
            raise TypeError(f"Unsupported space type for flattening: {type(leaf_space)}")

        if not low_parts:
            raise ValueError("The observation space appears to be empty or contain unsupported types.")

        frame_low = np.concatenate(low_parts)
        frame_high = np.concatenate(high_parts)
        # The per-frame feature count is the length of the concatenated bounds.
        num_features = int(frame_low.size)

        # 2D Box: one row of per-feature bounds for every stacked frame.
        self._observation_space = spaces.Box(
            low=frame_low,
            high=frame_high,
            shape=(self._env.frame_stack_size, num_features),
            dtype=frame_low.dtype,
        )

    def observation_space(self) -> spaces.Box:
        """Returns a Box space for the flattened observation."""
        return self._observation_space

    def _flatten_stack(self, obs_stack):
        """Apply the env's ``obs_to_flat_array`` to each frame of the stack."""
        return jax.vmap(self._env.obs_to_flat_array)(obs_stack)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, EnvState]:
        """Reset the wrapped AtariWrapper and return the flattened stack with its state."""
        stacked_obs, state = self._env.reset(key)
        return self._flatten_stack(stacked_obs), state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: AtariState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, EnvState, float, bool, Any]:
        """Step the wrapped AtariWrapper and flatten every frame of the new stack."""
        stacked_obs, state, reward, done, info = self._env.step(state, action)
        return self._flatten_stack(stacked_obs), state, reward, done, info
[docs]
@struct.dataclass
class PixelState:
    """State carried by PixelObsWrapper between calls."""
    # State of the wrapped AtariWrapper.
    atari_state: AtariState
    # Stack of preprocessed frames, newest last.
    image_stack: chex.Array
[docs]
class PixelObsWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that returns stacked, preprocessed pixel
    observations of shape (frame_stack_size, height, width, channels).
    Apply this wrapper after the AtariWrapper!

    Args:
        env: The AtariWrapper to wrap.
        do_pixel_resize: Resize every frame to ``pixel_resize_shape``.
        pixel_resize_shape: Target ``(height, width)`` used when resizing.
        grayscale: Convert every frame to a single luma channel.
    """

    def __init__(self, env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "PixelObsWrapper has to be applied after AtariWrapper"
        self.do_pixel_resize = do_pixel_resize
        self.pixel_resize_shape = pixel_resize_shape
        self.grayscale = grayscale

        # Dynamically calculate the final observation space shape
        base_shape = self._env.image_space().shape
        height, width, channels = base_shape
        if self.do_pixel_resize:
            height, width = self.pixel_resize_shape
        if self.grayscale:
            channels = 1
        final_shape = (height, width, channels)

        # Create the space for a single preprocessed frame
        image_space = spaces.Box(low=0, high=255, shape=final_shape, dtype=jnp.uint8)
        # Stack the single-frame space
        self._observation_space = spaces.stack_space(image_space, self._env.frame_stack_size)

    def observation_space(self) -> spaces.Box:
        """Returns the stacked image space."""
        return self._observation_space

    def _preprocess_image(self, image: chex.Array) -> chex.Array:
        """Applies resizing and grayscaling to a single image frame."""
        image = image.astype(jnp.float32)
        # Has to use a standard Python `if` since jax.lax.cond would fail due to different shapes. This is possible since do_pixel_resize is a static parameter.
        if self.do_pixel_resize:
            image = jim.resize(image, (self.pixel_resize_shape[0], self.pixel_resize_shape[1], image.shape[-1]), method='bilinear')
        # applies grayscale if enabled with the same method as for resize
        if self.grayscale:
            image = jnp.dot(image, jnp.array([0.2989, 0.5870, 0.1140]))[..., jnp.newaxis]  # numbers for grayscale transformation as in https://en.wikipedia.org/wiki/Luma_(video)
        return image.astype(jnp.uint8)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, PixelState]:
        """Reset the wrapped env and return a stack of identical first frames."""
        # The underlying AtariWrapper returns its own state, which we store.
        _, atari_state = self._env.reset(key)
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        # Create a stack of identical processed images for the initial state
        image_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        return image_stack, PixelState(atari_state, image_stack)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, EnvState, float, bool, Any]:
        """Step the wrapped env and push the newly rendered frame onto the stack.

        Returns:
            ``(image_stack, new_state, reward, done, info)``. When ``done`` is
            set, the stack is refilled with the first frame of the new episode.
        """
        # Pass the nested atari_state to the underlying wrapper's step function
        _, atari_state, reward, done, info = self._env.step(state.atari_state, action)
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        # Regular case: drop the oldest frame and append the new one.
        shifted_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)
        # AtariWrapper.step has already reset the env when `done` is set, so the
        # rendered frame belongs to a new episode. Rebuild the stack from it,
        # exactly like reset() does, instead of mixing in frames of the
        # finished episode.
        fresh_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        image_stack = jnp.where(done, fresh_stack, shifted_stack)
        # Create the new state with the *new* atari_state from the step
        new_state = PixelState(atari_state, image_stack)
        return image_stack, new_state, reward, done, info
[docs]
@struct.dataclass
class PixelAndObjectCentricState:
    """State shared by PixelAndObjectCentricWrapper and PixelAndObjectObsWrapper."""
    # State of the wrapped AtariWrapper.
    atari_state: AtariState
    # Stack of preprocessed frames, newest last.
    image_stack: chex.Array
    # Object-centric observation stack: flattened per frame in
    # PixelAndObjectCentricWrapper, structured in PixelAndObjectObsWrapper.
    obs_stack: chex.Array
[docs]
class PixelAndObjectCentricWrapper(JaxatariWrapper):
    """
    Wrapper for Atari environments that returns a tuple of the stacked,
    preprocessed pixel observations and the stacked object-centric
    observations (each frame flattened to a 1D feature vector).
    Apply this wrapper after the AtariWrapper!

    Args:
        env: The AtariWrapper to wrap.
        do_pixel_resize: Resize every frame to ``pixel_resize_shape``.
        pixel_resize_shape: Target ``(height, width)`` used when resizing.
        grayscale: Convert every frame to a single luma channel.
    """

    def __init__(self, env, do_pixel_resize: bool = False, pixel_resize_shape: tuple[int, int] = (84, 84), grayscale: bool = False):
        super().__init__(env)
        assert isinstance(env, AtariWrapper), "PixelAndObjectCentricWrapper must be applied after AtariWrapper"

        # Part 1: Define the stacked image space.
        self.do_pixel_resize = do_pixel_resize
        self.pixel_resize_shape = pixel_resize_shape
        self.grayscale = grayscale
        base_shape = self._env.image_space().shape
        height, width, channels = base_shape
        if self.do_pixel_resize:
            height, width = self.pixel_resize_shape
        if self.grayscale:
            channels = 1
        final_shape = (height, width, channels)
        image_space = spaces.Box(low=0, high=255, shape=final_shape, dtype=jnp.uint8)
        stacked_image_space = spaces.stack_space(image_space, self._env.frame_stack_size)

        # Part 2: Define the FLATTENED object space from the base env's
        # single-frame (unstacked) observation space.
        single_frame_space = self._env._env.observation_space()
        lows, highs = [], []
        single_frame_flat_size = 0
        for leaf_space in jax.tree.leaves(single_frame_space):
            if isinstance(leaf_space, spaces.Box):
                # Flatten the bounds arrays for Box spaces
                low_arr = np.broadcast_to(leaf_space.low, leaf_space.shape).flatten()
                high_arr = np.broadcast_to(leaf_space.high, leaf_space.shape).flatten()
                lows.append(low_arr)
                highs.append(high_arr)
                single_frame_flat_size += low_arr.size
            elif isinstance(leaf_space, spaces.Discrete):
                # A Discrete space flattens to a single value
                lows.append(np.array([0], dtype=leaf_space.dtype))
                highs.append(np.array([leaf_space.n - 1], dtype=leaf_space.dtype))
                single_frame_flat_size += 1
            else:
                raise TypeError(f"Unsupported space type for flattening: {type(leaf_space)}")
        if not lows:
            raise ValueError("The observation space appears to be empty or contain unsupported types.")
        single_frame_lows = np.concatenate(lows)
        single_frame_highs = np.concatenate(highs)
        stacked_object_space_flat = spaces.Box(
            low=single_frame_lows,
            high=single_frame_highs,
            shape=(self._env.frame_stack_size, int(single_frame_flat_size)),
            dtype=single_frame_lows.dtype
        )

        # Part 3: Combine them into the final Tuple space.
        self._observation_space = spaces.Tuple((
            stacked_image_space,
            stacked_object_space_flat
        ))

    def observation_space(self) -> spaces.Tuple:
        """Returns a Tuple space containing stacked image and object spaces."""
        return self._observation_space

    def _preprocess_image(self, image: chex.Array) -> chex.Array:
        """Applies resizing and grayscaling to a single image frame."""
        image = image.astype(jnp.float32)
        # Has to use a standard Python `if` since jax.lax.cond would fail due to different shapes. This is possible since do_pixel_resize is a static parameter.
        if self.do_pixel_resize:
            image = jim.resize(image, (self.pixel_resize_shape[0], self.pixel_resize_shape[1], image.shape[-1]), method='bilinear')
        # applies grayscale if enabled with the same method as for resize
        if self.grayscale:
            image = jnp.dot(image, jnp.array([0.2989, 0.5870, 0.1140]))[..., jnp.newaxis]  # numbers for grayscale transformation as in https://en.wikipedia.org/wiki/Luma_(video)
        return image.astype(jnp.uint8)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, EnvState]:
        """Reset the wrapped env and return ``((image_stack, flat_obs), state)``."""
        # 1. Get the initial object observation stack and state from the AtariWrapper
        obs_stack, atari_state = self._env.reset(key)
        # 2. Flatten the object-centric part
        flat_obs = jax.vmap(self._env.obs_to_flat_array)(obs_stack)
        # 3. Render and preprocess the image
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        image_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        # 4. Create the state and observation tuple
        new_state = PixelAndObjectCentricState(atari_state, image_stack, flat_obs)
        return (image_stack, flat_obs), new_state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelAndObjectCentricState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, EnvState, float, bool, Any]:
        """Step the wrapped env and update both observation stacks.

        Returns:
            ``((image_stack, flat_obs), new_state, reward, done, info)``. When
            ``done`` is set, the image stack is refilled with the first frame
            of the new episode so it stays aligned with the object stack.
        """
        # 1. Step the underlying environment using its state
        obs_stack, atari_state, reward, done, info = self._env.step(state.atari_state, action)
        # 2. Flatten the new object-centric observation stack
        flat_obs = jax.vmap(self._env.obs_to_flat_array)(obs_stack)
        # 3. Render and preprocess the new image
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        # 4. Update the image stack with the new processed image
        shifted_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)
        # AtariWrapper.step has already reset the env (and its object stack)
        # when `done` is set; rebuild the image stack from the new episode's
        # first frame so both modalities describe the same episode.
        fresh_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        image_stack = jnp.where(done, fresh_stack, shifted_stack)
        # 5. Create the new state with the new atari_state
        new_state = PixelAndObjectCentricState(atari_state, image_stack, flat_obs)
        return (image_stack, flat_obs), new_state, reward, done, info
[docs]
class PixelAndObjectObsWrapper(PixelAndObjectCentricWrapper):
    """
    Exactly the same as PixelAndObjectCentricWrapper, but return structured OC-obs instead of flattened array.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, EnvState]:
        """Reset the wrapped env and return ``((image_stack, obs_stack), state)``."""
        # 1. Get the initial object observation stack and state from the AtariWrapper
        obs_stack, atari_state = self._env.reset(key)
        # 2. Render and preprocess the image
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        image_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        # 3. Create the state and observation tuple
        new_state = PixelAndObjectCentricState(atari_state, image_stack, obs_stack)
        return (image_stack, obs_stack), new_state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: PixelAndObjectCentricState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, EnvState, float, bool, Any]:
        """Step the wrapped env and update the image stack.

        Returns:
            ``((image_stack, obs_stack), new_state, reward, done, info)``.
            When ``done`` is set, the image stack is refilled with the first
            frame of the new episode so it stays aligned with ``obs_stack``.
        """
        # 1. Step the underlying environment using its state
        obs_stack, atari_state, reward, done, info = self._env.step(state.atari_state, action)
        # 2. Render and preprocess the new image
        image = self._env.render(atari_state.env_state)
        processed_image = self._preprocess_image(image)
        # 3. Update the image stack with the new processed image
        shifted_stack = jnp.concatenate([state.image_stack[1:], jnp.expand_dims(processed_image, axis=0)], axis=0)
        # The wrapped AtariWrapper has already reset the env (and obs_stack)
        # when `done` is set; rebuild the image stack from the new episode's
        # first frame instead of keeping frames of the finished episode.
        fresh_stack = jnp.stack([processed_image] * self._env.frame_stack_size)
        image_stack = jnp.where(done, fresh_stack, shifted_stack)
        # 4. Create the new state with the new atari_state
        new_state = PixelAndObjectCentricState(atari_state, image_stack, obs_stack)
        return (image_stack, obs_stack), new_state, reward, done, info
[docs]
class FlattenObservationWrapper(JaxatariWrapper):
    """
    Flattens every leaf array of the observation pytree to 1D.
    The surrounding structure is preserved: a tuple of observations becomes a
    tuple of flattened observations, so the wrapper can be combined with all
    the other wrappers.
    """

    def __init__(self, env):
        super().__init__(env)

        def flatten_space(space: spaces.Box) -> spaces.Box:
            # Expand (possibly scalar) bounds to the full shape before
            # flattening so each element keeps its own bound.
            flat_low, flat_high = (
                np.broadcast_to(bound, space.shape).flatten()
                for bound in (space.low, space.high)
            )
            return spaces.Box(
                low=jnp.array(flat_low),
                high=jnp.array(flat_high),
                dtype=space.dtype,
            )

        # Rebuild the observation space with every Box replaced by its flat version.
        self._observation_space = jax.tree.map(
            flatten_space,
            self._env.observation_space(),
            is_leaf=lambda node: isinstance(node, spaces.Box),
        )

    def observation_space(self) -> spaces.Space:
        """Returns a space where each leaf array is flattened."""
        return self._observation_space

    def _process_obs(self, obs_tree: chex.ArrayTree) -> chex.ArrayTree:
        """Return ``obs_tree`` with ``.flatten()`` applied to every leaf."""
        def _flat(leaf):
            return leaf.flatten()

        return jax.tree.map(_flat, obs_tree)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.ArrayTree, Any]:
        """Reset the wrapped env; the state is passed through unchanged."""
        obs, state = self._env.reset(key)
        return self._process_obs(obs), state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: Any,
        action: Union[int, float],
    ) -> Tuple[chex.ArrayTree, Any, float, bool, Dict[str, Any]]:
        """Step the wrapped env and flatten the returned observation."""
        obs, next_state, reward, done, info = self._env.step(state, action)
        return self._process_obs(obs), next_state, reward, done, info
[docs]
class NormalizeObservationWrapper(JaxatariWrapper):
    """
    Rescales every leaf of the observation pytree with the bounds of the
    wrapped env's observation space: to [0, 1] by default, or to [-1, 1] when
    ``to_neg_one`` is set. Works with any observation structure (pytrees).
    """

    def __init__(self, env, to_neg_one: bool = False, dtype=jnp.float16):
        super().__init__(env)
        self._to_neg_one = to_neg_one
        self._dtype = dtype

        original_space = self._env.observation_space()

        def _is_box(node):
            return isinstance(node, spaces.Box)

        # Pytrees mirroring the observation structure that hold the low / high bounds.
        self._low = jax.tree.map(
            lambda s: jnp.array(s.low, dtype=s.dtype), original_space, is_leaf=_is_box
        )
        self._high = jax.tree.map(
            lambda s: jnp.array(s.high, dtype=s.dtype), original_space, is_leaf=_is_box
        )

        # The new space keeps structure and shapes; every Box gets the
        # normalized bounds and the target dtype.
        def _normalize_space(space: spaces.Box) -> spaces.Box:
            if self._to_neg_one:
                low_val = -1.0
            else:
                low_val = 0.0
            return spaces.Box(low=low_val, high=1.0, shape=space.shape, dtype=self._dtype)

        self._observation_space = jax.tree.map(_normalize_space, original_space, is_leaf=_is_box)

    def observation_space(self) -> spaces.Space:
        """Returns the normalized space: leaves in [0, 1], or [-1, 1] with ``to_neg_one``."""
        return self._observation_space

    def _normalize_leaf(self, obs_leaf, low_leaf, high_leaf):
        """Normalize one leaf array with its matching low/high bounds."""
        target = self._dtype
        values = obs_leaf.astype(target)
        lower = low_leaf.astype(target)
        # A (near-)zero range is replaced by 1 to avoid dividing by zero.
        span = high_leaf.astype(target) - lower
        scale = 1.0 / jnp.where(span > 1e-8, span, 1.0)
        # Map to [0, 1] ...
        unit_interval = (values - low_leaf.astype(target)) * scale
        # ... and optionally on to [-1, 1].
        shifted = jax.lax.cond(
            self._to_neg_one,
            lambda x: 2.0 * x - 1.0,
            lambda x: x,
            unit_interval,
        )
        # Clip so values outside the declared bounds stay in the target range.
        lower_limit = -1.0 if self._to_neg_one else 0.0
        return jnp.clip(shifted, lower_limit, 1.0)

    def _normalize_obs(self, obs: chex.ArrayTree) -> chex.ArrayTree:
        """
        Normalize every leaf of ``obs``. Observation and bound pytrees are
        matched leaf by leaf in flattening order, and the result is rebuilt
        with the structure of ``obs``.
        """
        leaf_triplets = zip(
            jax.tree.leaves(obs),
            jax.tree.leaves(self._low),
            jax.tree.leaves(self._high),
        )
        normalized_leaves = [self._normalize_leaf(o, lo, hi) for o, lo, hi in leaf_triplets]
        return jax.tree.unflatten(jax.tree.structure(obs), normalized_leaves)

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.ArrayTree, Any]:
        """Reset the wrapped env and normalize the first observation."""
        obs, state = self._env.reset(key)
        return self._normalize_obs(obs), state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: Any,
        action: Union[int, float],
    ) -> Tuple[chex.ArrayTree, Any, float, bool, Dict[str, Any]]:
        """Step the wrapped env and normalize the returned observation."""
        obs, next_state, reward, done, info = self._env.step(state, action)
        return self._normalize_obs(obs), next_state, reward, done, info
[docs]
@struct.dataclass
class LogState:
    """State carried by LogWrapper between calls."""
    atari_state: Any  # Can be any of the states from wrappers above
    # Sum of rewards in the running episode; set back to 0 when a step is done.
    episode_returns: float
    # Number of steps in the running episode; set back to 0 when a step is done.
    episode_lengths: int
    # Return of the most recently finished episode (kept until the next one ends).
    returned_episode_returns: float
    # Length of the most recently finished episode (kept until the next one ends).
    returned_episode_lengths: int
[docs]
class LogWrapper(JaxatariWrapper):
    """Tracks the return and length of every episode and reports them via ``info``."""

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey
    ) -> Tuple[chex.Array, LogState]:
        """Reset the wrapped env and start with all counters at zero."""
        obs, atari_state = self._env.reset(key)
        return obs, LogState(atari_state, 0.0, 0, 0.0, 0)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: LogState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogState, float, bool, Dict[Any, Any]]:
        """Step the wrapped env and update the running / last-finished statistics."""
        obs, atari_state, reward, done, info = self._env.step(state.atari_state, action)

        running_return = state.episode_returns + reward
        running_length = state.episode_lengths + 1
        # ``done`` works as a 0/1 switch: running counters are cleared and the
        # "returned" values are overwritten only when the episode ended.
        keep = 1 - done
        next_state = LogState(
            atari_state=atari_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=state.returned_episode_returns * keep + running_return * done,
            returned_episode_lengths=state.returned_episode_lengths * keep + running_length * done,
        )

        info["returned_episode_returns"] = next_state.returned_episode_returns
        info["returned_episode_lengths"] = next_state.returned_episode_lengths
        info["returned_episode"] = done
        return obs, next_state, reward, done, info
[docs]
@struct.dataclass
class MultiRewardLogState:
    """State carried by MultiRewardLogWrapper between calls."""
    atari_state: Any  # Can be any of the states from wrappers above
    # Running sum of the env reward in the current episode; cleared on done.
    episode_returns_env: float
    # Running sums of the info["all_rewards"] entries; cleared on done.
    episode_returns: chex.Array
    # Number of steps in the current episode; cleared on done.
    episode_lengths: int
    # Env return of the most recently finished episode.
    returned_episode_returns_env: float
    # Per-reward-function returns of the most recently finished episode.
    returned_episode_returns: chex.Array
    # Length of the most recently finished episode.
    returned_episode_lengths: int
[docs]
class MultiRewardLogWrapper(JaxatariWrapper):
    """Log the episode returns and lengths for multiple rewards.
    Make sure to apply MultiRewardWrapper to the core env when using this wrapper.
    The final logs will be 'returned_episode_returns_0', ... for each reward function provided.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey,
    ) -> Tuple[chex.Array, MultiRewardLogState]:
        """Reset the wrapped env and zero every tracker."""
        obs, atari_state = self._env.reset(key)
        # A throwaway step (action 0) is only used to learn the shape of
        # info["all_rewards"]; its state and observation are discarded.
        probe_info = self._env.step(atari_state, 0)[4]
        zero_returns = jnp.zeros_like(probe_info.get("all_rewards", jnp.zeros(1)))
        initial_state = MultiRewardLogState(atari_state, 0.0, zero_returns, 0, 0.0, zero_returns, 0)
        return obs, initial_state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        state: MultiRewardLogState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, MultiRewardLogState, float, bool, Dict[Any, Any]]:
        """Step the wrapped env and update the per-reward episode statistics."""
        obs, atari_state, reward, done, info = self._env.step(state.atari_state, action)

        # Without "all_rewards" in the info, fall back to zeros of the tracker's shape.
        step_rewards = info.get("all_rewards", jnp.zeros_like(state.episode_returns))
        running_env_return = state.episode_returns_env + reward
        running_returns = state.episode_returns + step_rewards
        running_length = state.episode_lengths + 1

        # ``done`` works as a 0/1 switch: running counters are cleared and the
        # "returned" values are overwritten only when the episode ended.
        keep = 1 - done
        next_state = MultiRewardLogState(
            atari_state=atari_state,
            episode_returns_env=running_env_return * keep,
            episode_returns=running_returns * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns_env=state.returned_episode_returns_env * keep + running_env_return * done,
            returned_episode_returns=state.returned_episode_returns * keep + running_returns * done,
            returned_episode_lengths=state.returned_episode_lengths * keep + running_length * done,
        )

        info["returned_episode_env_returns"] = next_state.returned_episode_returns_env
        # One log entry per reward function, indexed by its position.
        for idx, _ in enumerate(running_returns):
            info[f"returned_episode_returns_{idx}"] = next_state.returned_episode_returns[idx]
        info["returned_episode_lengths"] = next_state.returned_episode_lengths
        info["returned_episode"] = done
        return obs, next_state, reward, done, info