import importlib
import inspect
from jaxatari.environment import JaxEnvironment
from jaxatari.renderers import JAXGameRenderer
from jaxatari.modification import apply_modifications
from jaxatari.wrappers import JaxatariWrapper
# Registered game names. Dict insertion order below follows this tuple.
# Every game module lives at "jaxatari.games.jax_<name>", so the name alone
# determines its module path.
_GAME_NAMES = (
    "asterix",
    "asteroids",
    "atlantis",
    "bankheist",
    "berzerk",
    "blackjack",
    "breakout",
    "centipede",
    "choppercommand",
    "enduro",
    "fishingderby",
    "freeway",
    "frostbite",
    "galaxian",
    "haunted_house",
    "human_cannonball",
    "kangaroo",
    "kingkong",
    "klax",
    "lasergates",
    "namethisgame",
    "phoenix",
    "pong",
    "riverraid",
    "seaquest",
    "sir_lancelot",
    "skiing",
    "slotmachine",
    "spaceinvaders",
    "spacewar",
    "surround",
    "tennis",
    "tetris",
    "timepilot",
    "tron",
    "turmoil",
    "videocheckers",
    "videocube",
    "videopinball",
    "wordzapper",
    # Add new games here. A game whose module does not follow the
    # "jaxatari.games.jax_<name>" pattern must instead be added to
    # GAME_MODULES explicitly after the comprehension below.
)

# Map of game names to their module paths.
GAME_MODULES = {name: f"jaxatari.games.jax_{name}" for name in _GAME_NAMES}
# Mod registry: maps each moddable game name to the dotted import path
# ("package.module.ClassName") of its mod class. make() passes this mapping
# to apply_modifications() when mods are requested.
MOD_MODULES = dict(
    pong="jaxatari.games.mods.pong_mods.PongEnvMod",
    kangaroo="jaxatari.games.mods.kangaroo_mods.KangarooEnvMod",
    freeway="jaxatari.games.mods.freeway_mods.FreewayEnvMod",
    breakout="jaxatari.games.mods.breakout_mods.BreakoutEnvMod",
    seaquest="jaxatari.games.mods.seaquest_mods.SeaquestEnvMod",
)
[docs]
def list_available_games() -> list[str]:
    """Return the names of all registered games, in registration order."""
    # Unpacking a dict yields its keys in insertion order.
    return [*GAME_MODULES]
[docs]
def make(game_name: str,
         mode: int = 0,
         difficulty: int = 0,
         mods_config: list = None,
         allow_conflicts: bool = False
         ) -> JaxEnvironment:
    """
    Create and return a JaxAtari game environment instance.

    This is the main entry point for creating environments. If
    ``mods_config`` is non-empty, construction is delegated to
    ``apply_modifications``, which applies the modding pipeline:

    1. Pre-scans for constant overrides.
    2. Instantiates the base env with modded constants.
    3. Applies the internal 'JaxAtariModController'.
    4. Wraps the env with the 'JaxAtariModWrapper'.

    Args:
        game_name: Name of the game to load (e.g., "pong"); must be a key
            of ``GAME_MODULES``.
        mode: Game mode. NOTE(review): currently not forwarded to the
            environment; confirm whether env constructors should receive it.
        difficulty: Game difficulty. NOTE(review): currently not forwarded
            to the environment either.
        mods_config: Optional list of mods to apply. ``None`` or an empty
            list returns the unmodified base environment.
        allow_conflicts: Forwarded unchanged to ``apply_modifications``.

    Returns:
        An instance of the specified game environment; when mods are
        requested, whatever ``apply_modifications`` returns.

    Raises:
        NotImplementedError: If ``game_name`` is not registered.
        ImportError: If the game module cannot be imported, defines no
            ``JaxEnvironment`` subclass, or an ``ImportError`` /
            ``NotImplementedError`` is raised while building the env.
    """
    if game_name not in GAME_MODULES:
        raise NotImplementedError(
            f"The game '{game_name}' does not exist. Available games: {list_available_games()}"
        )
    module_path = GAME_MODULES[game_name]
    try:
        # 1. Load the base environment class.
        module = importlib.import_module(module_path)
        candidates = [
            obj for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, JaxEnvironment) and obj is not JaxEnvironment
        ]
        # getmembers() sorts by name, so a JaxEnvironment subclass that the
        # game module merely imports could shadow the game's own class.
        # Prefer classes defined in the module itself; fall back to any
        # candidate so modules that only re-export an env class still work.
        local = [cls for cls in candidates if cls.__module__ == module.__name__]
        pool = local or candidates
        if not pool:
            raise ImportError(f"No JaxEnvironment subclass found in {module_path}")
        env_class = pool[0]

        # 2. Get default constants from a default-constructed instance.
        base_consts = env_class().consts

        # 3. Handle mods if requested (None or [] means "no mods").
        if mods_config:
            return apply_modifications(
                game_name=game_name,
                mods_config=mods_config,
                allow_conflicts=allow_conflicts,
                base_consts=base_consts,
                env_class=env_class,
                MOD_MODULES=MOD_MODULES
            )
        # No mods: return the base env with default constants.
        return env_class(consts=base_consts)
    except (ImportError, NotImplementedError) as e:
        # Only wrap registration/import errors - let intentional errors
        # (ValueError, etc.) propagate unchanged.
        raise ImportError(f"Failed to load game '{game_name}': {e}") from e
[docs]
def make_renderer(game_name: str) -> JAXGameRenderer:
    """
    Create and return a JaxAtari game renderer.

    Args:
        game_name: Name of the game to load (e.g., "pong"); must be a key
            of ``GAME_MODULES``.

    Returns:
        An instance of the game's ``JAXGameRenderer`` subclass, constructed
        with no arguments.

    Raises:
        NotImplementedError: If ``game_name`` is not registered.
        ImportError: If the game module cannot be imported, defines no
            ``JAXGameRenderer`` subclass, or an ``ImportError`` /
            ``AttributeError`` is raised while constructing the renderer.
    """
    if game_name not in GAME_MODULES:
        raise NotImplementedError(
            f"The game '{game_name}' does not exist. Available games: {list_available_games()}"
        )
    module_path = GAME_MODULES[game_name]
    try:
        # 1. Dynamically load the game module.
        module = importlib.import_module(module_path)

        # 2. Find the renderer class within the module.
        candidates = [
            obj for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, JAXGameRenderer) and obj is not JAXGameRenderer
        ]
        # getmembers() sorts by name, so an imported renderer subclass could
        # shadow the game's own one. Prefer classes defined in this module;
        # fall back to any candidate for modules that only re-export one.
        local = [cls for cls in candidates if cls.__module__ == module.__name__]
        pool = local or candidates
        if not pool:
            raise ImportError(f"No JAXGameRenderer subclass found in {module_path}")

        # 3. Instantiate the renderer (no arguments) and return it.
        return pool[0]()
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Failed to load renderer for '{game_name}': {e}") from e