Source code for jaxatari.core

import importlib
import inspect
import warnings

from jaxatari.environment import JaxEnvironment
from jaxatari.renderers import JAXGameRenderer
from jaxatari.modification import apply_modifications
from jaxatari.wrappers import JaxatariWrapper

from . import check_ownership


def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None:
    """Emit a DeprecationWarning if ``env`` still carries a callable ``obs_to_flat_array``.

    Environments that lack the attribute, or whose attribute is not callable,
    pass through without any warning.

    Args:
        env: The environment instance to inspect.
    """
    legacy_hook = getattr(env, "obs_to_flat_array", None)
    if not callable(legacy_hook):
        return

    message = (
        "Environment exposes deprecated obs_to_flat_array(). "
        "Observations should now be flax.struct.dataclasses using ObjectObservation "
        "for objects or plain arrays for observations like lives, score, etc. "
        "Depending on legacy obs_to_flat_array might lead to unforseen issues with wrappers."
    )
    # stacklevel=2 attributes the warning to whoever called this helper.
    warnings.warn(message, DeprecationWarning, stacklevel=2)



# Registry of playable games: public game name -> dotted path of the module
# that implements it. Every module follows the "jaxatari.games.jax_<name>"
# pattern, so only the names are listed. Commented-out names are WIP and are
# meant to be supported in the near future.
GAME_MODULES = {
    game: f"jaxatari.games.jax_{game}"
    for game in (
        "amidar",
        # "airraid",
        "alien",
        "asterix",
        "asteroids",
        "atlantis",
        "bankheist",
        "beamrider",
        "berzerk",
        "blackjack",
        "breakout",
        "centipede",
        "choppercommand",
        "enduro",
        "fishingderby",
        "flagcapture",
        "freeway",
        "frostbite",
        "galaxian",
        "gravitar",
        # "hangman",
        "hauntedhouse",
        "humancannonball",
        "kangaroo",
        "kingkong",
        # "klax",
        "lasergates",
        "namethisgame",
        "phoenix",
        "pong",
        "qbert",
        "riverraid",
        "seaquest",
        "sirlancelot",
        "skiing",
        "slotmachine",
        "spaceinvaders",
        "spacewar",
        # "surround",  currently not in a state that can be used
        "tennis",
        "tetris",
        "timepilot",
        "tron",
        "turmoil",
        "venture",
        # "videocheckers",
        "videocube",
        "videopinball",
        "wordzapper",
        "mspacman",
        "montezumarevenge",
        # "pacman",
        # Add new games here
    )
}

# Registry of mod entry points: game name -> dotted path of that game's mod
# (controller) class. Each row is (game, module under jaxatari.games.mods,
# class name); games without a row have no mod class registered here.
MOD_MODULES = {
    game: f"jaxatari.games.mods.{mod_module}.{mod_class}"
    for game, mod_module, mod_class in (
        ("pong", "pong_mods", "PongEnvMod"),
        ("kangaroo", "kangaroo_mods", "KangarooEnvMod"),
        ("freeway", "freeway_mods", "FreewayEnvMod"),
        ("breakout", "breakout_mods", "BreakoutEnvMod"),
        ("seaquest", "seaquest_mods", "SeaquestEnvMod"),
        ("videopinball", "videopinball_mods", "VideoPinballEnvMod"),
        ("tennis", "tennis_mods", "TennisEnvMod"),
        ("fishingderby", "fishingderby_mods", "FishingDerbyEnvMod"),
        ("atlantis", "atlantis_mods", "AtlantisEnvMod"),
        ("bankheist", "bankheist_mods", "BankHeistEnvMod"),
        ("montezumarevenge", "montezuma_revenge_mods", "MontezumaRevengeEnvMod"),
        ("frostbite", "frostbite_mods", "FrostbiteEnvMod"),
        ("gravitar", "gravitar_mods", "GravitarEnvMod"),
        ("phoenix", "phoenix_mods", "PhoenixEnvMod"),
        ("enduro", "enduro_mods", "EnduroEnvMod"),
        ("qbert", "qbert_mods", "QbertEnvMod"),
        ("mspacman", "mspacman_mods", "MsPacmanEnvMod"),
        ("beamrider", "beamrider_mods", "BeamRiderEnvMod"),
        ("venture", "venture_mods", "VentureEnvMod"),
        ("spaceinvaders", "spaceinvaders_mods", "SpaceInvadersEnvMod"),
        ("skiing", "skiing_mods", "SkiingEnvMod"),
        ("alien", "alien_mods", "AlienEnvMod"),
        ("asteroids", "asteroids_mods", "AsteroidsEnvMod"),
    )
}


def list_available_games() -> list[str]:
    """Lists all available, registered games.

    Returns:
        A new list holding the keys of ``GAME_MODULES`` in registration order.
    """
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return list(GAME_MODULES)
def make(game_name: str,
         mode: int = 0,
         difficulty: int = 0,
         mods_config: list = None,  # deprecated, output warning if its used
         mods: list = None,
         allow_conflicts: bool = False
         ) -> JaxEnvironment:
    """
    Creates and returns a JaxAtari game environment instance.
    This is the main entry point for creating environments.

    Without mods, the first ``JaxEnvironment`` subclass found in the game's
    module is instantiated with no arguments and returned.

    If 'mods' is provided, the default constants of a freshly constructed
    base environment are handed to ``apply_modifications`` together with the
    environment class and the ``MOD_MODULES`` registry; whatever that call
    returns is the result. The modding pipeline itself (constant overrides,
    mod controller, mod wrapper) lives inside ``apply_modifications``.
    If ``apply_modifications`` raises ``NotImplementedError``, a
    ``UserWarning`` is emitted and the unmodified base environment is
    returned instead.

    Args:
        game_name: Name of the game to load (e.g., "pong"). Strings are
            lower-cased before the lookup in ``GAME_MODULES``.
        mode: Game mode. NOTE(review): accepted but not read anywhere in
            this function - confirm whether it should reach the environment.
        difficulty: Game difficulty. NOTE(review): accepted but not read
            anywhere in this function - confirm as for ``mode``.
        mods_config: Deprecated alias of ``mods``. When it is not None a
            ``DeprecationWarning`` is emitted and its value replaces ``mods``.
        mods: List of modifications to apply (default: None). An empty list
            behaves like None.
        allow_conflicts: Whether to allow conflicting mods (default: False).

    Returns:
        An instance of the specified game environment.

    Raises:
        NotImplementedError: If ``game_name`` is not a key of ``GAME_MODULES``.
        ImportError: If the game module cannot be imported, contains no
            ``JaxEnvironment`` subclass, or constructing the base environment
            raises ``ImportError``/``NotImplementedError``. Other exception
            types (e.g. ``ValueError``) propagate unchanged.
    """
    check_ownership()  # Ensure ownership confirmed

    if isinstance(game_name, str):
        game_name = game_name.lower()

    if mods_config is not None:
        # stacklevel=2 attributes the warning to the caller of make(); Python's
        # default filters hide DeprecationWarnings attributed to library code.
        warnings.warn(
            "'mods_config' is deprecated and will be removed in future versions. "
            "Please use 'mods' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        mods = mods_config

    if game_name not in GAME_MODULES:
        raise NotImplementedError(
            f"The game '{game_name}' does not exist. Available games: {list_available_games()}"
        )

    module_path = GAME_MODULES[game_name]
    try:
        # 1. Load the base environment class: the first proper JaxEnvironment
        #    subclass among the module's members (getmembers sorts by name).
        module = importlib.import_module(module_path)
        env_class = next(
            (
                obj
                for _, obj in inspect.getmembers(module, inspect.isclass)
                if issubclass(obj, JaxEnvironment) and obj is not JaxEnvironment
            ),
            None,
        )
        if env_class is None:
            raise ImportError(f"No JaxEnvironment subclass found in {module_path}")

        # 2. Mods need default consts for pre-scan; otherwise a single env_class() is enough.
        if mods:
            try:
                base_consts = env_class().consts
                env = apply_modifications(
                    game_name=game_name,
                    mods_config=mods,
                    allow_conflicts=allow_conflicts,
                    base_consts=base_consts,
                    env_class=env_class,
                    MOD_MODULES=MOD_MODULES
                )
                _warn_deprecated_obs_to_flat_array(env)
                return env
            except NotImplementedError as e:
                # Mod module not defined for this game - fall back to base environment
                warnings.warn(
                    f"Mods requested for '{game_name}' but no mod module is available. "
                    f"Creating base environment without mods. Error: {e}",
                    UserWarning,
                    stacklevel=2,
                )

        env = env_class()
        _warn_deprecated_obs_to_flat_array(env)
        return env

    except (ImportError, NotImplementedError) as e:
        # Only wrap registration/import errors - let intentional errors (ValueError, etc.) propagate
        raise ImportError(f"Failed to load game '{game_name}': {e}") from e