# Source code for jaxatari.games.jax_freeway

import os
from functools import partial
import chex
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from typing import Tuple, NamedTuple, List, Dict, Optional, Any

from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action
import jaxatari.spaces as spaces
from jaxatari.renderers import JAXGameRenderer
from jaxatari.rendering import jax_rendering_utils as render_utils

def _get_default_asset_config() -> tuple:
    """
    Returns the default declarative asset manifest for Freeway.
    Kept immutable (tuple of dicts) to fit NamedTuple defaults.
    Note: Recolorings are added dynamically in the renderer based on constants.
    """
    return (
        {'name': 'background', 'type': 'background', 'file': 'background.npy'},
        {
            'name': 'player', 'type': 'group',
            'files': ['player_hit.npy', 'player_walk.npy', 'player_idle.npy']
        },
        {'name': 'car_dark_red', 'type': 'single', 'file': 'car_dark_red.npy'},
        {'name': 'car_light_green', 'type': 'single', 'file': 'car_light_green.npy'},
        {'name': 'car_dark_green', 'type': 'single', 'file': 'car_dark_green.npy'},
        {'name': 'car_light_red', 'type': 'single', 'file': 'car_light_red.npy'},
        {'name': 'car_blue', 'type': 'single', 'file': 'car_blue.npy'},
        {'name': 'car_brown', 'type': 'single', 'file': 'car_brown.npy'},
        {'name': 'car_light_blue', 'type': 'single', 'file': 'car_light_blue.npy'},
        {'name': 'car_red', 'type': 'single', 'file': 'car_red.npy'},
        {'name': 'car_green', 'type': 'single', 'file': 'car_green.npy'},
        {'name': 'car_yellow', 'type': 'single', 'file': 'car_yellow.npy'},
        {'name': 'score_digits', 'type': 'digits', 'pattern': 'score_{}.npy'},
    )
 
"""Per-lane car movement timing (frames per pixel, sign = direction).
Negative values move left, positive values move right. Absolute value is the
frame interval at which the car advances by one pixel.
THIS IS THE CONSTANT THAT DEFINES THE 10 DIFFERENT PATTERNS.
"""
# Index i is the timing for lane i: the sign is the direction (negative =
# left, positive = right) and the magnitude is the number of frames per
# one-pixel advance. FreewayConstants.CAR_UPDATES carries the same values
# as its default.
CAR_UPDATE: List[int] = [
    -5, -4, -3, -2, -1,  # Lanes 0-4 (move left)
    1, 2, 3, 4, 5,       # Lanes 5-9 (move right)
]

 

class FreewayConstants(NamedTuple):
    """Tunable constants for the Freeway environment and its renderer.

    Every field has a default, so ``FreewayConstants()`` yields the standard
    configuration. Field order is part of the interface (positional
    construction), so new fields must only be appended.
    """

    # --- Screen and sprite geometry (pixels) ---
    screen_width: int = 160
    screen_height: int = 210
    chicken_width: int = 6
    chicken_height: int = 8
    chicken_x: int = 44  # Fixed x position
    car_width: int = 8
    car_height: int = 10
    num_lanes: int = 10
    lane_spacing: int = 16
    car_speeds: List[float] = None
    # Placeholder so the field keeps its position; the actual default list is
    # assigned further down, once top_border / top_path are defined.
    lane_borders: List[int] = None
    top_border: int = 15
    top_path: int = 8
    bottom_border: int = 180

    # --- Collision response tuning ---
    throw_back_frames: int = 24  # frames the chicken is pushed back after hit
    stun_frames: int = 28  # frames the chicken cannot move after hit
    # After scoring (reaching the top and resetting), prevent movement for N frames
    post_score_stun_frames: int = 28
    # Vertical offset to apply to chicken spawn after scoring (positive = lower on screen)
    post_score_spawn_offset_y: int = 1

    # --- Collision box insets (shrink AABB without changing render sizes) ---
    chicken_hit_inset_x: int = 1
    chicken_hit_inset_y_top: int = -2  # Top edge of chicken (when cars approach from above)
    chicken_hit_inset_y_bottom: int = 0  # Bottom edge of chicken (when cars approach from below)
    car_hit_inset_x: int = 0
    car_hit_inset_y_top: int = 2  # Top edge of car (for cars approaching from above)
    car_hit_inset_y_bottom: int = 0  # Bottom edge of car (for cars approaching from below)

    # Fine-tune horizontal respawn offset applied when wrapping.
    # Positive shifts right-moving lanes further right on re-entry
    # (and vice versa for left-moving).
    respawn_offset: int = 8
    # Fine-tune vertical car alignment within each lane (applied at reset)
    car_y_offset: int = 1

    # Per-lane cadence phase offset (frames) for N-frame movement; allows
    # aligning cadence to reference. Taken modulo the lane period at reset.
    cadence_phase_offset: List[int] = (
        -2, -2, -2, 0, 0, 0, 0, -2, -2, 3
    )

    # Period and direction of each lane's movement pattern.
    # Upper 5 lanes move left (-), lower 5 lanes move right (+); the absolute
    # value at i is the number of frames between one-pixel moves in lane i.
    CAR_UPDATES: List[int] = [
        -5,  # Lane 0
        -4,  # Lane 1
        -3,  # Lane 2
        -2,  # Lane 3
        -1,  # Lane 4
        1,   # Lane 5
        2,   # Lane 6
        3,   # Lane 7
        4,   # Lane 8
        5,   # Lane 9
    ]

    # Per-lane initial phase offsets in pixels to align with ALE (added to x
    # at reset). Lanes 0-4 move left; lanes 5-9 move right.
    lane_phase_offset: List[int] = [
        5,    # lane 0 (+5 px)
        5,    # lane 1
        5,    # lane 2
        5,    # lane 3
        6,    # lane 4
        156,  # lane 5 (+156 px)
        157,  # lane 6
        157,  # lane 7
        157,  # lane 8
        157,  # lane 9
    ]

    # Lane boundaries (y coordinates): entry i is the border of lane i.
    # This rebinding supplies the real default for the `lane_borders` field
    # annotated above.
    lane_borders = [
        top_border + top_path,                          # Lane 0
        1 * lane_spacing + (top_border + top_path),     # Lane 1
        2 * lane_spacing + (top_border + top_path),     # Lane 2
        3 * lane_spacing + (top_border + top_path),     # Lane 3
        4 * lane_spacing + (top_border + top_path),     # Lane 4
        5 * lane_spacing + (top_border + top_path),     # Lane 5
        6 * lane_spacing + (top_border + top_path),     # Lane 6
        7 * lane_spacing + (top_border + top_path),     # Lane 7
        8 * lane_spacing + (top_border + top_path),     # Lane 8
        9 * lane_spacing + (top_border + top_path),     # Lane 9
        10 * lane_spacing + (top_border + top_path) + 2,  # Border after lane 9 (+2 px)
    ]

    # Car colors for each lane (10 lanes). If color is None, use original
    # sprite color. Otherwise, recolor the car sprite to the specified RGB
    # color. Note: Use None for original color, (0, 0, 0) for actual black.
    CAR_COLORS: List[Optional[Tuple[int, int, int]]] = [
        None,  # Lane 0 - use original color
        None,  # Lane 1 - use original color
        None,  # Lane 2 - use original color
        None,  # Lane 3 - use original color
        None,  # Lane 4 - use original color
        None,  # Lane 5 - use original color
        None,  # Lane 6 - use original color
        None,  # Lane 7 - use original color
        None,  # Lane 8 - use original color
        None,  # Lane 9 - use original color
    ]

    # --- Game duration config ---
    # Original Atari 2600 timer logic results in exactly 8192 frames
    game_duration_frames: int = 8192
    # Score starts blinking at 2:00 (7680 frames) to warn players of imminent game over
    blink_start_frames: int = 7680
    # Rate at which score colors cycle (frames per color change)
    score_blink_rate: int = 2
    # Colors for the blinking score cycle (RGB)
    SCORE_BLINK_COLORS: List[Tuple[int, int, int]] = (
        (210, 210, 64),  # Yellow (original)
        (210, 64, 64),   # Red
        (64, 210, 64),   # Green
        (64, 64, 210),   # Blue
        (210, 64, 210),  # Magenta
        (64, 210, 210),  # Cyan
    )

    # Asset config baked into constants (immutable default) for asset overrides
    ASSET_CONFIG: tuple = _get_default_asset_config()
class FreewayState(NamedTuple):
    """Represents the current state of the game"""

    chicken_y: chex.Array
    cars: chex.Array  # Shape: (num_lanes, 2) for x,y positions (ints for render/collide)
    # Per-lane cadence counters (frames), advance independently to sync
    # movement patterns per lane
    lane_time: chex.Array
    score: chex.Array
    time: chex.Array  # Frame counter, incremented by one on every step
    cooldown: chex.Array  # Cooldown after collision
    walking_frames: chex.Array  # Animation counter; reset when the chicken stops
    lives_lost: chex.Array  # Number of collisions registered so far
    game_over: chex.Array
class EntityPosition(NamedTuple):
    """Position and size of a single on-screen entity."""

    x: jnp.ndarray
    y: jnp.ndarray
    width: jnp.ndarray
    height: jnp.ndarray
class FreewayObservation(NamedTuple):
    """Object-centric observation: the player chicken plus one car per lane."""

    chicken: EntityPosition
    car: jnp.ndarray  # Shape: (10, 4) with x,y,width,height for each car
class FreewayInfo(NamedTuple):
    """Auxiliary per-step information returned alongside the observation."""

    time: jnp.ndarray  # Frame counter copied from the state
class JaxFreeway(JaxEnvironment[FreewayState, FreewayObservation, FreewayInfo, FreewayConstants]):
    """JAX implementation of the Freeway environment.

    The agent steers a chicken vertically across ten lanes of traffic. Each
    crossing (reaching the top border) scores one point and sends the chicken
    back to the bottom; a collision pushes it back and then stuns it.
    """

    # Map agent action indices (0, 1, 2) to ALE console actions
    # 0 -> NOOP, 1 -> UP, 2 -> DOWN
    ACTION_SET: jnp.ndarray = jnp.array(
        [Action.NOOP, Action.UP, Action.DOWN],
        dtype=jnp.int32
    )

    def __init__(self, consts: FreewayConstants = None):
        """Create the environment.

        Args:
            consts: Constants to use; defaults to ``FreewayConstants()``.
        """
        if consts is None:
            consts = FreewayConstants()
        super().__init__(consts)
        # NOTE(review): reset() returns (observation, state), so this
        # attribute holds that pair rather than a bare FreewayState. Kept
        # as-is because external callers may rely on it.
        self.state = self.reset()
        self.renderer = FreewayRenderer(self.consts)

    def reset(self, key: jax.random.PRNGKey = None) -> Tuple[FreewayObservation, FreewayState]:
        """Initialize a new game state

        Args:
            key: Accepted for interface compatibility; not used, the initial
                state is deterministic.

        Returns:
            The initial observation and the initial state.
        """
        consts = self.consts

        # Start chicken at bottom
        chicken_y = consts.bottom_border + consts.chicken_height - 1

        # Initialize one car per lane
        cars = []
        for lane in range(consts.num_lanes):
            lane_y = (
                consts.lane_borders[lane]
                + int(consts.lane_spacing / 2)
                - int(consts.car_height / 2)
            ) + int(consts.car_y_offset)
            # Upper 5 lanes start from right, lower 5 lanes start from left
            if lane < 5:
                x = consts.screen_width - consts.car_width + 0  # Start from right
            else:
                x = 0  # Start from left
            cars.append([x, lane_y])
        cars = jnp.array(cars, dtype=jnp.int32)

        # Apply per-lane phase offsets
        phase = jnp.array(consts.lane_phase_offset, dtype=jnp.int32)
        cars = cars.at[:, 0].add(phase)

        # Initialize per-lane cadence counters using configured phase offsets
        periods0 = jnp.abs(jnp.array(consts.CAR_UPDATES, dtype=jnp.int32))
        phases0 = jnp.array(consts.cadence_phase_offset, dtype=jnp.int32) % periods0

        state = FreewayState(
            chicken_y=jnp.array(chicken_y, dtype=jnp.int32),
            cars=cars,
            lane_time=phases0,
            score=jnp.array(0, dtype=jnp.int32),
            time=jnp.array(0, dtype=jnp.int32),
            cooldown=jnp.array(0, dtype=jnp.int32),
            walking_frames=jnp.array(0, dtype=jnp.int32),
            lives_lost=jnp.array(0, dtype=jnp.int32),
            game_over=jnp.array(False, dtype=jnp.bool_),
        )
        return self._get_observation(state), state

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state: FreewayState, action: int) -> tuple[FreewayObservation, FreewayState, float, bool, FreewayInfo]:
        """Take a step in the game given an action

        Args:
            state: Current game state.
            action: Index into ``ACTION_SET`` (0 = NOOP, 1 = UP, 2 = DOWN).

        Returns:
            ``(observation, new_state, reward, done, info)``.
        """
        consts = self.consts

        # Translate compact agent action (0, 1, 2) to ALE console action constant
        atari_action = jnp.take(self.ACTION_SET, action)

        # --- Chicken movement ---
        # The cooldown counts down from stun_frames + throw_back_frames: the
        # upper part of that range forces the chicken downwards (throw back),
        # the lower part freezes it (stun).
        in_throw_back = jnp.logical_and(
            state.cooldown > consts.stun_frames,
            state.cooldown <= (consts.stun_frames + consts.throw_back_frames),
        )
        in_stun = jnp.logical_and(
            state.cooldown > 0, state.cooldown <= consts.stun_frames
        )
        requested_dy = jnp.where(
            atari_action == Action.UP,
            -1.0,
            jnp.where(atari_action == Action.DOWN, 1.0, 0.0),
        )
        dy = jnp.where(in_throw_back, 1.0, requested_dy)
        dy = jnp.where(in_stun, 0.0, dy)

        # add one to the walking frames if dy != 0, if it is 0 reset to 0
        new_walking_frames = jnp.where(dy != 0, state.walking_frames + 1, 0)
        # reset new_walking frames at 8
        new_walking_frames = jnp.where(new_walking_frames >= 8, 0, new_walking_frames)

        new_y = jnp.clip(
            state.chicken_y + dy.astype(jnp.int32),
            consts.top_border,
            consts.bottom_border + consts.chicken_height - 1,
        ).astype(jnp.int32)

        # --- Car movement ---
        # Implements the [0, 0, ..., 1] repeating pattern based on CAR_UPDATES
        car_updates = jnp.array(consts.CAR_UPDATES, dtype=jnp.int32)
        periods = jnp.abs(car_updates)
        signs = jnp.sign(car_updates)

        # Per-lane cadence counters: move when the counter reaches period-1,
        # then keep advancing
        should_move_mask = (state.lane_time == (periods - 1))
        delta_x = jnp.where(should_move_mask, signs, 0).astype(jnp.int32)

        # Apply moves to integer x positions
        pre_x = state.cars[:, 0]
        x_int = pre_x + delta_x

        # Wrap positions to [-car_width, screen_width)
        range_len_i = consts.screen_width + consts.car_width
        x_int_wrapped = ((x_int + consts.car_width) % range_len_i) - consts.car_width

        # Apply respawn offset only to entries that wrapped this frame
        wrapped_right = jnp.logical_and(signs > 0, x_int >= consts.screen_width)
        wrapped_left = jnp.logical_and(signs < 0, x_int < -consts.car_width)
        offset = jnp.asarray(consts.respawn_offset, dtype=jnp.int32)
        adjusted = x_int_wrapped
        adjusted = jnp.where(wrapped_right, x_int_wrapped + offset, adjusted)
        adjusted = jnp.where(wrapped_left, x_int_wrapped - offset, adjusted)
        # Keep within valid range after offset
        adjusted = jnp.clip(adjusted, -consts.car_width, consts.screen_width - 1)

        # Update integer car positions
        new_cars = state.cars.at[:, 0].set(adjusted.astype(jnp.int32))

        # Advance per-lane cadence counters
        new_lane_time = (state.lane_time + 1) % periods

        # --- Collisions ---
        # Chicken AABB with insets. It is the same for every car, so it is
        # computed once outside the vmapped check. It uses the chicken's
        # position from before this step's move.
        cxi = jnp.asarray(consts.chicken_hit_inset_x, dtype=jnp.int32)
        cyi_top = jnp.asarray(consts.chicken_hit_inset_y_top, dtype=jnp.int32)
        cyi_bottom = jnp.asarray(consts.chicken_hit_inset_y_bottom, dtype=jnp.int32)
        ch_x0 = consts.chicken_x + cxi
        ch_x1 = consts.chicken_x + consts.chicken_width - cxi
        ch_y0 = state.chicken_y - consts.chicken_height + cyi_top
        ch_y1 = state.chicken_y - cyi_bottom

        # Car AABB insets
        kxi = jnp.asarray(consts.car_hit_inset_x, dtype=jnp.int32)
        kyi_top = jnp.asarray(consts.car_hit_inset_y_top, dtype=jnp.int32)
        kyi_bottom = jnp.asarray(consts.car_hit_inset_y_bottom, dtype=jnp.int32)

        def check_collision(car_pos):
            """Return whether one car's inset box overlaps the chicken's."""
            car_x, car_y = car_pos
            car_x0 = car_x + kxi
            car_x1 = car_x + consts.car_width - kxi
            car_y0 = car_y - consts.car_height + kyi_top
            car_y1 = car_y - kyi_bottom
            overlap_x = jnp.logical_and(ch_x0 < car_x1, ch_x1 > car_x0)
            overlap_y = jnp.logical_and(ch_y0 < car_y1, ch_y1 > car_y0)
            return jnp.logical_and(overlap_x, overlap_y)

        # Check collisions for all cars; hits are ignored while a cooldown
        # is already running.
        collisions = jax.vmap(check_collision)(new_cars)
        any_collision = jnp.logical_and(
            jnp.any(collisions), jnp.logical_not(state.cooldown > 0)
        )

        # Update cooldown
        new_cooldown = jnp.where(
            any_collision,
            consts.throw_back_frames + consts.stun_frames,
            jnp.maximum(0, state.cooldown - 1),
        ).astype(jnp.int32)

        # --- Scoring ---
        # Update score if chicken reaches top
        scored = new_y <= consts.top_border
        new_score = jnp.where(scored, state.score + 1, state.score).astype(jnp.int32)

        # Reset chicken position if scored
        new_y = jnp.where(
            scored,
            consts.bottom_border + consts.chicken_height - 1 + consts.post_score_spawn_offset_y,
            new_y,
        ).astype(jnp.int32)

        # Apply a post-score stun to prevent immediate movement after crossing once
        new_cooldown = jnp.where(
            scored,
            jnp.maximum(
                new_cooldown,
                jnp.asarray(consts.post_score_stun_frames, dtype=jnp.int32),
            ),
            new_cooldown,
        )

        # --- Time / termination ---
        new_time = (state.time + 1).astype(jnp.int32)

        # Check game over based on exact frame count
        game_over = jnp.where(
            new_time >= consts.game_duration_frames,
            jnp.array(True),
            state.game_over,
        )

        new_live_lost = jnp.where(
            any_collision,
            state.lives_lost + 1,
            state.lives_lost,
        )

        new_state = FreewayState(
            chicken_y=new_y,
            cars=new_cars,
            lane_time=new_lane_time,
            score=new_score,
            time=new_time,
            cooldown=new_cooldown,
            walking_frames=new_walking_frames.astype(jnp.int32),
            lives_lost=new_live_lost,
            game_over=game_over,
        )

        done = self._get_done(new_state)
        env_reward = self._get_reward(state, new_state)
        obs = self._get_observation(new_state)
        info = self._get_info(new_state)
        return obs, new_state, env_reward, done, info

    @partial(jax.jit, static_argnums=(0,))
    def _get_observation(self, state: FreewayState):
        """Build the object-centric observation for ``state``."""
        consts = self.consts

        # create chicken
        chicken = EntityPosition(
            x=jnp.array(consts.chicken_x, dtype=jnp.int32),
            y=state.chicken_y,
            width=jnp.array(consts.chicken_width, dtype=jnp.int32),
            height=jnp.array(consts.chicken_height, dtype=jnp.int32),
        )

        # create cars: one (x, y, width, height) row per lane, built in a
        # single concatenate instead of one scatter per lane
        car_sizes = jnp.broadcast_to(
            jnp.array([consts.car_width, consts.car_height], dtype=jnp.int32),
            (consts.num_lanes, 2),
        )
        car_xy = state.cars[:consts.num_lanes, :2].astype(jnp.int32)
        cars = jnp.concatenate([car_xy, car_sizes], axis=1)

        return FreewayObservation(chicken=chicken, car=cars)

    @partial(jax.jit, static_argnums=(0,))
    def _get_info(self, state: FreewayState) -> FreewayInfo:
        """Return the info tuple (currently only the frame counter)."""
        return FreewayInfo(time=state.time)

    @partial(jax.jit, static_argnums=(0,))
    def _get_reward(self, previous_state: FreewayState, state: FreewayState):
        """Return the score gained between ``previous_state`` and ``state``."""
        return state.score - previous_state.score

    @partial(jax.jit, static_argnums=(0,))
    def _get_done(self, state: FreewayState) -> bool:
        """Return whether the episode has ended."""
        return state.game_over

    def action_space(self) -> spaces.Discrete:
        """Returns the action space for Freeway."""
        return spaces.Discrete(len(self.ACTION_SET))

    def observation_space(self) -> spaces.Dict:
        """Returns the observation space for Freeway.

        The observation contains:
        - chicken: EntityPosition (x, y, width, height)
        - car: array of shape (num_lanes, 4) with x,y,width,height for each car

        The car lower bound is ``-car_width`` because ``step`` keeps car x
        positions in ``[-car_width, screen_width - 1]``: cars leaving on the
        left are partially off-screen before they wrap.
        """
        consts = self.consts
        return spaces.Dict({
            "chicken": spaces.Dict({
                "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
                "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
            }),
            "car": spaces.Box(
                low=-consts.car_width,
                high=210,
                shape=(consts.num_lanes, 4),
                dtype=jnp.int32,
            ),
        })

    def image_space(self) -> spaces.Box:
        """Returns the image space for Freeway.

        The image is a RGB image with shape (210, 160, 3).
        """
        return spaces.Box(
            low=0,
            high=255,
            shape=(210, 160, 3),
            dtype=jnp.uint8
        )

    def render(self, state: FreewayState) -> jnp.ndarray:
        """Render the game state to a raster image."""
        return self.renderer.render(state)

    def obs_to_flat_array(self, obs: FreewayObservation) -> jnp.ndarray:
        """Convert observation to a flat array.

        Layout: chicken x, y, width, height followed by the row-major
        flattened car array.
        """
        # Flatten chicken position and dimensions
        chicken_flat = jnp.concatenate([
            obs.chicken.x.reshape(-1),
            obs.chicken.y.reshape(-1),
            obs.chicken.width.reshape(-1),
            obs.chicken.height.reshape(-1)
        ])

        # Flatten car positions and dimensions
        cars_flat = obs.car.reshape(-1)

        # Concatenate all components
        return jnp.concatenate([chicken_flat, cars_flat]).astype(jnp.int32)
class FreewayRenderer(JAXGameRenderer):
    """Palette-based renderer that draws a ``FreewayState`` as an RGB frame."""

    # Lane index -> car sprite name. Shared by asset setup and rendering so
    # the two can never disagree.
    _CAR_SPRITE_NAMES = (
        'car_dark_red',     # Lane 0
        'car_light_green',  # Lane 1
        'car_dark_green',   # Lane 2
        'car_light_red',    # Lane 3
        'car_blue',         # Lane 4
        'car_brown',        # Lane 5
        'car_light_blue',   # Lane 6
        'car_red',          # Lane 7
        'car_green',        # Lane 8
        'car_yellow',       # Lane 9
    )

    def __init__(self, consts: FreewayConstants = None):
        """Load sprites and precompute everything ``render`` needs.

        Args:
            consts: Constants to use; defaults to ``FreewayConstants()``.
        """
        super().__init__()
        self.consts = consts or FreewayConstants()
        self.config = render_utils.RendererConfig(
            game_dimensions=(210, 160),
            channels=3,
            #downscale=(84, 84)
        )
        self.jr = render_utils.JaxRenderingUtils(self.config)

        # Copy every asset dict before modifying it. ASSET_CONFIG is a
        # default shared by all FreewayConstants instances, so writing
        # 'recolorings' into its dicts would leak into other renderers.
        asset_config = [dict(asset) for asset in self.consts.ASSET_CONFIG]
        sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/freeway"

        # Add recoloring rules to car assets if colors are specified
        for asset in asset_config:
            asset_name = asset.get('name')
            if asset_name not in self._CAR_SPRITE_NAMES:
                continue
            lane_idx = self._CAR_SPRITE_NAMES.index(asset_name)
            color = self.consts.CAR_COLORS[lane_idx]
            if color is not None:
                # Add recoloring rule: global replace with target color
                recolorings = dict(asset.get('recolorings', {}))
                recolorings['recolored'] = {'target': color}
                asset['recolorings'] = recolorings

        # Add recoloring rules to score_digits for blink colors
        for asset in asset_config:
            if asset.get('name') == 'score_digits':
                recolorings = dict(asset.get('recolorings', {}))
                # Add a recolored variant for each blink color
                for idx, color in enumerate(self.consts.SCORE_BLINK_COLORS):
                    recolorings[f'blink_{idx}'] = {'target': color}
                asset['recolorings'] = recolorings
                break

        # Create black bar sprite at initialization time
        black_bar_sprite = self._create_black_bar_sprite()

        # Append procedural assets
        asset_config.append({
            'name': 'black_bar',
            'type': 'procedural',
            'data': black_bar_sprite
        })

        # Load all assets, create palette, and generate ID masks
        (
            self.PALETTE,
            self.SHAPE_MASKS,
            self.BACKGROUND,
            self.COLOR_TO_ID,
            self.FLIP_OFFSETS
        ) = self.jr.load_and_setup_assets(asset_config, sprite_path)

        # Setup score masks tensor from recolored variants
        self.score_masks_tensor = self._setup_score_masks_tensor()

    def _setup_score_masks_tensor(self) -> jnp.ndarray:
        """
        Creates a tensor of score digit masks for all required color palettes.

        Index 0: Default color
        Index 1..N: Blinking colors

        Returns:
            jnp.ndarray: Shape (NumPalettes, 10, H, W)
        """
        # 1. Get default masks; stack them if they come back as a list
        default_masks = self.SHAPE_MASKS["score_digits"]
        if isinstance(default_masks, list):
            default_masks = jnp.stack(default_masks)

        all_palettes = [default_masks]

        # 2. Get recolored masks for each blink color (created via recoloring system)
        for idx in range(len(self.consts.SCORE_BLINK_COLORS)):
            blink_key = f"score_digits_blink_{idx}"
            if blink_key in self.SHAPE_MASKS:
                blink_masks = self.SHAPE_MASKS[blink_key]
                if isinstance(blink_masks, list):
                    blink_masks = jnp.stack(blink_masks)
                all_palettes.append(blink_masks)
            else:
                # Fallback: if recoloring didn't create the variant, use default
                all_palettes.append(default_masks)

        # 3. Stack all palettes into one master tensor
        # Shape: (NumColors+1, 10, H, W)
        return jnp.stack(all_palettes)

    def _create_black_bar_sprite(self) -> jnp.ndarray:
        """Create a black bar sprite for the left side of the screen."""
        # Create an 8-pixel wide black bar covering the full height
        bar_height = self.consts.screen_height
        bar_width = 8

        # Create black sprite with full alpha (255) so it gets added to palette
        black_bar = jnp.zeros((bar_height, bar_width, 4), dtype=jnp.uint8)
        black_bar = black_bar.at[:, :, 3].set(255)  # Set alpha to 255
        return black_bar

    @partial(jax.jit, static_argnums=(0,))
    def render(self, state):
        """Render ``state`` to an image.

        Draw order: background, the fixed right-hand chicken, the player
        chicken, the cars, both scores, and finally the black bar at the
        left edge.
        """
        raster = self.jr.create_object_raster(self.BACKGROUND)

        # Draw fixed chicken (right side - "Computer")
        chicken_idle_mask = self.SHAPE_MASKS["player"][2]
        raster = self.jr.render_at(
            raster,
            110,
            self.consts.bottom_border + self.consts.chicken_height - 1,
            chicken_idle_mask,
        )

        # Draw active chicken (left side - Player 1)
        # Player frames: 0 = hit, 1 = walk, 2 = idle.
        use_idle = state.walking_frames < 4
        chicken_frame_index = jax.lax.select(use_idle, 2, 1)

        is_hit = state.cooldown > 0
        chicken_frame_index = jax.lax.select(
            jnp.logical_and(
                is_hit,
                jnp.logical_or((state.cooldown % 8) < 4, state.cooldown < 30),
            ),
            0,
            chicken_frame_index
        )

        chicken_mask = self.SHAPE_MASKS["player"][chicken_frame_index]
        raster = self.jr.render_at(raster, self.consts.chicken_x, state.chicken_y, chicken_mask)

        # Draw cars
        for i in range(self.consts.num_lanes):
            sprite_name = self._CAR_SPRITE_NAMES[i]
            # Use recolored variant if color is specified, otherwise use original
            if self.consts.CAR_COLORS[i] is not None:
                mask_key = f"{sprite_name}_recolored"
            else:
                mask_key = sprite_name
            car_mask = self.SHAPE_MASKS[mask_key]
            raster = self.jr.render_at_clipped(raster, state.cars[i, 0], state.cars[i, 1], car_mask)

        # --- SCORE RENDERING ---
        should_blink = state.time >= self.consts.blink_start_frames
        blink_cycle_idx = (
            (state.time // self.consts.score_blink_rate)
            % len(self.consts.SCORE_BLINK_COLORS)
        )

        # Use direct access for default (matches original behavior) or tensor
        # for blinking. This ensures exact compatibility with the original
        # version when not blinking.
        def get_default_masks():
            return self.SHAPE_MASKS["score_digits"]

        def get_blink_masks():
            # Index 0 of the tensor is the default palette, so blink colors
            # start at 1.
            palette_index = blink_cycle_idx + 1
            return self.score_masks_tensor[palette_index]

        current_score_masks = jax.lax.cond(
            should_blink,
            get_blink_masks,
            get_default_masks
        )

        # 1. Player 1 Score (Left)
        score_digits_p1 = self.jr.int_to_digits(state.score, max_digits=2)
        is_single_digit_p1 = state.score < 10
        start_index_p1 = jax.lax.select(is_single_digit_p1, 1, 0)
        num_to_render_p1 = jax.lax.select(is_single_digit_p1, 1, 2)
        render_x_p1 = jax.lax.select(is_single_digit_p1, 49, 41)

        raster = self.jr.render_label_selective(
            raster, render_x_p1, 5,
            score_digits_p1, current_score_masks,
            start_index_p1, num_to_render_p1,
            spacing=8
        )

        # 2. Player 2 / Computer Score (Right - Dummy 0 for now)
        # Position logic: Right chicken is at 110. Offset is similar to left (110 + 5ish)
        # Center of right lane roughly 115.
        score_digits_p2 = self.jr.int_to_digits(0, max_digits=1)  # Always 0
        render_x_p2 = 113  # Fixed position for "0"

        # Render a single '0' on the right side.
        raster = self.jr.render_label_selective(
            raster, render_x_p2, 5,
            score_digits_p2, current_score_masks,
            0, 1,
            spacing=8
        )

        # Draw black bar
        black_bar_mask = self.SHAPE_MASKS["black_bar"]
        raster = self.jr.render_at(raster, 0, 0, black_bar_mask)

        return self.jr.render_from_palette(raster, self.PALETTE)