from enum import Enum
from typing import Tuple, Generic, TypeVar
import jax.numpy as jnp
import jax.random as jrandom
import warnings
from jaxatari.spaces import Space
from flax import struct
EnvObs = TypeVar("EnvObs")
EnvState = TypeVar("EnvState")
EnvInfo = TypeVar("EnvInfo")
EnvConstants = TypeVar("EnvConstants")
[docs]
class JAXAtariAction:
    """
    Plain holder for the integer codes of the Atari actions.

    The members are ordinary Python ints, so they can be placed
    directly into JAX arrays.
    """
    NOOP: int = 0
    FIRE: int = 1
    UP: int = 2
    RIGHT: int = 3
    LEFT: int = 4
    DOWN: int = 5
    UPRIGHT: int = 6
    UPLEFT: int = 7
    DOWNRIGHT: int = 8
    DOWNLEFT: int = 9
    UPFIRE: int = 10
    RIGHTFIRE: int = 11
    LEFTFIRE: int = 12
    DOWNFIRE: int = 13
    UPRIGHTFIRE: int = 14
    UPLEFTFIRE: int = 15
    DOWNRIGHTFIRE: int = 16
    DOWNLEFTFIRE: int = 17

    @classmethod
    def get_all_values(cls) -> jnp.ndarray:
        """
        Collects every action code into one array.
        Returns: An int32 array holding the action codes in declaration order.
        """
        # The action set is fixed, so the members are spelled out explicitly
        # instead of being discovered by introspection.
        ordered_names = (
            "NOOP", "FIRE", "UP", "RIGHT", "LEFT", "DOWN",
            "UPRIGHT", "UPLEFT", "DOWNRIGHT", "DOWNLEFT",
            "UPFIRE", "RIGHTFIRE", "LEFTFIRE", "DOWNFIRE",
            "UPRIGHTFIRE", "UPLEFTFIRE", "DOWNRIGHTFIRE", "DOWNLEFTFIRE",
        )
        # Looked up on cls so that values overridden in a subclass are honoured.
        codes = [getattr(cls, member) for member in ordered_names]
        return jnp.array(codes, dtype=jnp.int32)
[docs]
@struct.dataclass
class ObjectObservation:
    """
    Dataclass for object centric observations of objects in jaxatari environments.
    Can hold 1 to N objects of the same type (for example 12 sharks in seaquest or 1 player ship in asteroids).
    Should always be instantiated via the create() classmethod to ensure proper default handling.
    Attributes:
        x: x position of the object.
        y: y position of the object.
        width: width of the object.
        height: height of the object.
        active: whether the object is currently active (0 or 1).
        visual_id: visual identifier of the object (format depends on game).
        state: state of the object (format depends on game).
        orientation: angle of the object (format depends on game).
    """
    x: jnp.ndarray  # obligatory (int8)
    y: jnp.ndarray  # obligatory (int8)
    width: jnp.ndarray  # obligatory (int8)
    height: jnp.ndarray  # obligatory (int8)
    # --- Additional attributes (default to 0 when not used, except `active`, which defaults to 1) ---
    active: jnp.ndarray = struct.field(default_factory=lambda: jnp.array(1))  # whether the object is currently active (0 or 1)
    visual_id: jnp.ndarray = struct.field(default_factory=lambda: jnp.array(0))  # visual identifier of the object (different color sprites for example)
    state: jnp.ndarray = struct.field(default_factory=lambda: jnp.array(0))  # state of the object, for example is the ghost in pacman vulnerable [blinking] or not [static] (format depends on game, see the game docs)
    orientation: jnp.ndarray = struct.field(default_factory=lambda: jnp.array(0))  # angle of the object (format depends on game, see the game docs)

    @classmethod
    def create(cls, x, y, width, height, active=None, visual_id=None, state=None, orientation=None):
        """
        Builds an observation, filling every omitted optional field with an
        int32 array shaped like ``x``.
        Args:
            x, y, width, height: obligatory fields, stored unchanged.
            active: defaults to all ones when None.
            visual_id: defaults to all zeros when None.
            state: defaults to all zeros when None.
            orientation: defaults to all zeros when None.
        Returns: A new instance of ``cls``.
        """
        # Helper to handle defaults
        if active is None:
            active = jnp.ones_like(x, dtype=jnp.int32)
        if visual_id is None:
            visual_id = jnp.zeros_like(x, dtype=jnp.int32)
        if state is None:
            state = jnp.zeros_like(x, dtype=jnp.int32)
        if orientation is None:
            orientation = jnp.zeros_like(x, dtype=jnp.int32)
        return cls(x=x, y=y, width=width, height=height, active=active, visual_id=visual_id, state=state, orientation=orientation)

    def __repr__(self):
        """
        Human-readable summary: one line for a 0-d observation, otherwise a
        table with one row per object (at most 20 rows are printed).
        Never raises an ``Exception``; any failure is reported inside the
        returned string instead.
        """
        try:
            # Handle scalar case (0-d arrays)
            if self.x.ndim == 0:
                try:
                    # Try to get concrete values for cleaner output
                    x, y = int(self.x), int(self.y)
                    w, h = int(self.width), int(self.height)
                    act = int(self.active)
                    ori = float(self.orientation)
                    st = int(self.state)
                    vid = int(self.visual_id)
                    status = "ACTIVE" if act else "INACTIVE"
                    return (f"Object(Single, {status}): Pos=({x}, {y}) | Size=({w}, {h}) | "
                            f"Ori={ori:.1f} | State={st} | VisID={vid}")
                except Exception:
                    # Fallback for Tracers. Deliberately not a bare `except:` so that
                    # KeyboardInterrupt / SystemExit still propagate.
                    return f"Object(Single): Pos=({self.x}, {self.y}) | Active={self.active}"
            # Handle vector case (1-d arrays)
            n = self.x.shape[0]
            lines = [f"ObjectGroup(count={n}):"]
            # Limit print length if too huge
            limit = min(n, 20)
            for i in range(limit):
                try:
                    # Try to extract concrete values
                    act = int(self.active[i])
                    status = "ACTIVE" if act else " - "  # Dim inactive ones
                    x, y = int(self.x[i]), int(self.y[i])
                    w, h = int(self.width[i]), int(self.height[i])
                    ori = float(self.orientation[i])
                    st = int(self.state[i])
                    vid = int(self.visual_id[i])
                    # Formatted table row
                    line = (f" [{i:2d}] {status} | Pos: ({x:3d}, {y:3d}) | Size: ({w:2d}, {h:2d}) | "
                            f"Ori: {ori:5.1f} | State: {st:2d} | VisID: {vid:2d}")
                except Exception:
                    # Fallback for Tracers (same reasoning as in the scalar branch)
                    line = f" [{i}] Active={self.active[i]} | Pos=({self.x[i]}, {self.y[i]})"
                lines.append(line)
            if n > limit:
                lines.append(f" ... ({n - limit} more objects) ...")
            return "\n".join(lines)
        except Exception as e:
            return f"ObjectObservation(Error in __repr__: {e})"
[docs]
class JaxEnvironment(Generic[EnvState, EnvObs, EnvInfo, EnvConstants]):
    """
    Abstract class for a JAX environment.
    Generics:
        EnvState: The type of the environment state.
        EnvObs: The type of the observation.
        EnvInfo: The type of the additional information.
        EnvConstants: The type of the environment constants.
    """

    def __init__(self, consts: EnvConstants = None):
        """
        Stores the constants and initialises the modding bookkeeping.
        Args:
            consts: The environment constants, or None. A NamedTuple-style
                value (a tuple exposing ``_fields``) triggers a UserWarning.
        """
        if consts is not None:
            # Check for legacy NamedTuple usage (has _fields but is not a PyTreeNode)
            is_named_tuple = isinstance(consts, tuple) and hasattr(consts, '_fields')
            # Check if it's a flax.struct.PyTreeNode instance. `struct` is already
            # imported at module level, so only a missing attribute is tolerated here.
            pytree_node_cls = getattr(struct, "PyTreeNode", None)
            is_flax_node = pytree_node_cls is not None and isinstance(consts, pytree_node_cls)
            if is_named_tuple and not is_flax_node:
                warnings.warn(
                    f"Performance Warning: {self.__class__.__name__}.consts is a 'NamedTuple'. "
                    "This prevents JAX from treating constants as static metadata, potentially causing excessive recompilation. "
                    "Future versions will require 'flax.struct.PyTreeNode' (and the states/observations/info to flax.struct.dataclass/PyTreeNode). "
                    "Please refactor your constants class.",
                    UserWarning,
                    stacklevel=2
                )
        self.consts = consts
        # --- MODDING INFRASTRUCTURE ---
        # Functional: Tracks which renderer methods mods have patched.
        # Used by wrappers to safely transfer patches during renderer swaps.
        self._patched_renderer_methods = []
        # Functional: Explicit registry of jitted callables that must be invalidated
        # when renderer hot-swaps occur (e.g., native downscaling).
        self._jit_invalidation_targets = []
        # Functional: mutation epoch + tripwire controls for detecting risky
        # post-trace monkeypatching.
        self._jit_mutation_epoch = 0
        self._jit_tripwire_enabled = True
        # Informational: Structured audit log of every change made by the mod system.
        # Machine-parseable: dict of category -> set of names that were changed.
        # Categories: "attribute", "method", "constant", "asset".
        self._mod_history = {
            "attribute": set(),
            "method": set(),
            "constant": set(),
            "asset": set(),
        }

    def reset(self, key: jrandom.PRNGKey=None) -> Tuple[EnvObs, EnvState]:
        """
        Resets the environment to the initial state.
        Args:
            key: Optional PRNG key.
        Returns: The initial observation and the initial environment state.
        """
        raise NotImplementedError("Abstract method")

    def step(
        self, state: EnvState, action
    ) -> Tuple[EnvObs, EnvState, float, bool, EnvInfo]:
        """
        Takes a step in the environment.
        Args:
            state: The current environment state.
            action: The action to take.
        Returns: The observation, the new environment state, the reward, whether the state is terminal, and additional info.
        """
        raise NotImplementedError("Abstract method")

    # NOTE(review): the return annotation is a 1-tuple while the docstring
    # describes a single image - confirm which one implementations follow.
    def render(self, state: EnvState) -> Tuple[jnp.ndarray]:
        """
        Renders the environment state to a single image.
        Args:
            state: The environment state.
        Returns: A single image of the environment state.
        """
        raise NotImplementedError("Abstract method")

    def action_space(self) -> Space:
        """
        Returns the action space of the environment as an array containing the actions that can be taken.
        Returns: The action space of the environment as an array.
        """
        raise NotImplementedError("Abstract method")

    def observation_space(self) -> Space:
        """
        Returns the observation space of the environment.
        Returns: The observation space of the environment.
        """
        raise NotImplementedError("Abstract method")

    def image_space(self) -> Space:
        """
        Returns the image space of the environment.
        Returns: The image space of the environment.
        """
        raise NotImplementedError("Abstract method")

    def _get_observation(self, state: EnvState) -> EnvObs:
        """
        Converts the environment state to the observation by filtering out non-relevant information.
        Args:
            state: The environment state.
        Returns: observation
        """
        raise NotImplementedError("Abstract method")

    def _get_info(self, state: EnvState, all_rewards: jnp.ndarray = None) -> EnvInfo:
        """
        Extracts information from the environment state that is not relevant for the agent.
        Args:
            state: The environment state.
            all_rewards: Optional array of rewards (defaults to None); its layout is defined by the implementing environment.
        Returns: info
        """
        raise NotImplementedError("Abstract method")

    def _get_reward(self, previous_state: EnvState, state: EnvState) -> float:
        """
        Calculates the reward from the environment state.
        Args:
            previous_state: The previous environment state.
            state: The environment state.
        Returns: reward
        """
        raise NotImplementedError("Abstract method")

    def _get_done(self, state: EnvState) -> bool:
        """
        Determines if the environment state is a terminal state
        Args:
            state: The environment state.
        Returns: True if the state is terminal, False otherwise.
        """
        raise NotImplementedError("Abstract method")