Basic Usage

Environment Creation

The main entry point is the make() function:

import jax
import jaxatari

# Create an environment
env = jaxatari.make("pong")  # or "seaquest", "kangaroo", "freeway", etc.

# Get available games
available_games = jaxatari.list_available_games()
print(f"Available games: {available_games}")
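The environment returned by make() uses a functional interface. As the vectorized example later in this section shows, reset takes a PRNG key and returns (obs, state), and step takes (state, action) and returns obs, state, reward, terminated, truncated and info. Because the state is passed in and returned explicitly instead of being stored on the object, these calls should also work under jax.jit. The sketch below mimics that interface with a hypothetical `ToyEnv` (not part of JAXAtari) so it runs on its own. With a real environment you would call env.reset and env.step the same way.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    pos: jax.Array  # hypothetical 1-D position
    t: jax.Array    # steps taken so far


class ToyEnv:
    """Hypothetical stand-in that mimics the reset/step signatures used in this section."""

    max_steps = 10

    def reset(self, key):
        # Start at a random integer position in [-3, 3]
        pos = jax.random.randint(key, (), -3, 4)
        state = ToyState(pos=pos, t=jnp.zeros((), jnp.int32))
        return self._obs(state), state

    def step(self, state, action):
        # Action 1 moves right; any other action moves left
        pos = state.pos + jnp.where(action == 1, 1, -1)
        new_state = ToyState(pos=pos, t=state.t + 1)
        reward = jnp.where(pos == 0, 1.0, 0.0)
        terminated = jnp.abs(pos) >= 5
        truncated = new_state.t >= self.max_steps
        return self._obs(new_state), new_state, reward, terminated, truncated, {}

    def _obs(self, state):
        return state.pos.astype(jnp.float32)[None]


env = ToyEnv()
# State goes in and comes out explicitly, so both calls can be jit-compiled
reset = jax.jit(env.reset)
step = jax.jit(env.step)

obs, state = reset(jax.random.PRNGKey(0))
obs, state, reward, terminated, truncated, info = step(state, jnp.int32(1))
```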

Using Modifications

JAXAtari provides pre-implemented game modifications to test agent generalization:

import jaxatari

# Pong environment with the lazy_enemy mod
mod_env = jaxatari.make("pong", mods=["lazy_enemy"])

# Multiple mods can be applied simultaneously
mod_env = jaxatari.make("pong", mods=["lazy_enemy", "shift_enemy"])

Custom modifications are supported through the JaxAtariModController. Feel free to share them by opening a pull request.

Using Wrappers

JAXAtari provides a comprehensive wrapper system for different observation types, plus utilities such as logging:

import jaxatari
from jaxatari.wrappers import (
    AtariWrapper,
    ObjectCentricWrapper,
    PixelObsWrapper,
    PixelAndObjectCentricWrapper,
    FlattenObservationWrapper,
    LogWrapper,
)

base_env = jaxatari.make("pong")
atari_env = AtariWrapper(base_env)

env = ObjectCentricWrapper(atari_env)          # object-centric features
# OR
env = PixelObsWrapper(atari_env)               # pixel observations
# OR
env = PixelAndObjectCentricWrapper(atari_env)  # both
# OR
env = FlattenObservationWrapper(ObjectCentricWrapper(atari_env))  # flattened

# Add logging wrapper for training
env = LogWrapper(env)
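Conceptually, an observation wrapper delegates reset and step to the environment it wraps and transforms only the observation, passing the state through unchanged. The sketch below illustrates that pattern with a hypothetical `ToyObjectEnv` whose observation is a dict of per-object features, and a hypothetical `FlattenObsSketch` that concatenates the leaves of the observation pytree into one vector. It is not JAXAtari's FlattenObservationWrapper, whose internals may differ, but it shows why flattening is applied on top of the object-centric wrapper rather than beneath it.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    player_y: jax.Array  # hypothetical paddle position
    ball_xy: jax.Array   # hypothetical ball coordinates


class ToyObjectEnv:
    """Hypothetical env whose observation is a dict of per-object features."""

    def reset(self, key):
        state = ToyState(player_y=jnp.float32(0.5), ball_xy=jax.random.uniform(key, (2,)))
        return self._obs(state), state

    def step(self, state, action):
        state = state._replace(player_y=state.player_y + 0.1 * action)
        return self._obs(state), state, jnp.float32(0.0), jnp.bool_(False), jnp.bool_(False), {}

    def _obs(self, state):
        return {"player": state.player_y[None], "ball": state.ball_xy}


class FlattenObsSketch:
    """Hypothetical wrapper: delegates to the inner env and flattens the observation pytree."""

    def __init__(self, env):
        self._env = env

    def _flatten(self, obs):
        # tree_leaves visits dict entries in sorted key order ("ball", then "player")
        leaves = jax.tree_util.tree_leaves(obs)
        return jnp.concatenate([jnp.ravel(x) for x in leaves])

    def reset(self, key):
        obs, state = self._env.reset(key)
        return self._flatten(obs), state

    def step(self, state, action):
        obs, state, reward, terminated, truncated, info = self._env.step(state, action)
        return self._flatten(obs), state, reward, terminated, truncated, info


env = FlattenObsSketch(ToyObjectEnv())
obs, state = jax.jit(env.reset)(jax.random.PRNGKey(0))
obs, state, *_ = jax.jit(env.step)(state, jnp.float32(1.0))
```

Because each wrapper only touches the observation, wrappers compose by nesting, and the outermost one determines the final observation format.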

Vectorized Stepping

JAXAtari is designed for massive parallelization via jax.vmap and jax.lax.scan:

import jax
import jaxatari
from jaxatari.wrappers import AtariWrapper, ObjectCentricWrapper, FlattenObservationWrapper

base_env = jaxatari.make("pong")
env = FlattenObservationWrapper(ObjectCentricWrapper(AtariWrapper(base_env)))

n_envs = 1024
rng = jax.random.PRNGKey(0)
# Split the root key so that no key is consumed twice
rng, reset_rng, action_rng = jax.random.split(rng, 3)
reset_keys = jax.random.split(reset_rng, n_envs)

# Initialize n_envs parallel environments
init_obs, env_state = jax.vmap(env.reset)(reset_keys)

# Take one random step in each env
action = jax.random.randint(action_rng, (n_envs,), 0, env.action_space().n)
new_obs, new_env_state, reward, terminated, truncated, info = jax.vmap(env.step)(env_state, action)

# Take 100 steps with scan, sampling fresh random actions at every step
def step_fn(carry, unused):
    env_state, key = carry
    key, action_key = jax.random.split(key)
    action = jax.random.randint(action_key, (n_envs,), 0, env.action_space().n)
    new_obs, new_env_state, reward, terminated, truncated, info = jax.vmap(env.step)(env_state, action)
    return (new_env_state, key), (reward, terminated, truncated, info)

carry = (env_state, rng)
_, (rewards, terminations, truncations, infos) = jax.lax.scan(
    step_fn, carry, None, length=100
)
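The scan above does no reset handling of its own. This section does not say whether JAXAtari or its wrappers reset finished episodes automatically, so check that before training. If they do not, a common JAX pattern is to compute fresh reset states inside the scan and pick, per environment, between the fresh and the stepped state with jnp.where. The sketch below shows that pattern on a hypothetical `ToyEnv`, using a hypothetical `rollout` helper (neither is JAXAtari API), and records each finished episode's return.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    pos: jax.Array
    t: jax.Array


class ToyEnv:
    """Hypothetical stand-in: a 1-D walk that ends when |pos| reaches 3 or after 8 steps."""

    num_actions = 2

    def reset(self, key):
        # This toy reset ignores the key and always starts at the origin
        state = ToyState(pos=jnp.zeros((), jnp.int32), t=jnp.zeros((), jnp.int32))
        return state.pos.astype(jnp.float32)[None], state

    def step(self, state, action):
        pos = state.pos + jnp.where(action == 1, 1, -1)
        new_state = ToyState(pos=pos, t=state.t + 1)
        terminated = jnp.abs(pos) >= 3
        truncated = new_state.t >= 8
        reward = jnp.float32(1.0)  # +1 per step, so an episode's return equals its length
        return pos.astype(jnp.float32)[None], new_state, reward, terminated, truncated, {}


def rollout(env, key, n_envs, n_steps):
    key, reset_key = jax.random.split(key)
    _, state = jax.vmap(env.reset)(jax.random.split(reset_key, n_envs))

    def step_fn(carry, _):
        state, ep_return, key = carry
        key, action_key, reset_key = jax.random.split(key, 3)
        action = jax.random.randint(action_key, (n_envs,), 0, env.num_actions)
        _, state, reward, terminated, truncated, _ = jax.vmap(env.step)(state, action)
        done = terminated | truncated
        ep_return = ep_return + reward
        # Reset finished envs: take the fresh state wherever done is True.
        # Every leaf here has shape (n_envs,); higher-rank leaves would need done reshaped.
        _, fresh = jax.vmap(env.reset)(jax.random.split(reset_key, n_envs))
        state = jax.tree_util.tree_map(lambda f, s: jnp.where(done, f, s), fresh, state)
        finished_return = jnp.where(done, ep_return, 0.0)
        ep_return = jnp.where(done, 0.0, ep_return)
        return (state, ep_return, key), (done, finished_return)

    init = (state, jnp.zeros(n_envs, jnp.float32), key)
    _, (dones, finished_returns) = jax.lax.scan(step_fn, init, None, length=n_steps)
    return dones, finished_returns


dones, finished_returns = rollout(ToyEnv(), jax.random.PRNGKey(0), n_envs=4, n_steps=20)
```

Computing a reset for every environment at every step and discarding most of the results is wasteful, but it keeps the loop branch-free, which is the usual trade-off under vmap, where both sides of a select are evaluated anyway.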

Manual Play

To play a game manually with keyboard input, install pygame and use the provided script:

pip install pygame
python3 scripts/play.py -g Pong