Source code for jaxatari.games.jax_seaquest

import os
from functools import partial
from typing import Tuple, NamedTuple
import jax
import jax.numpy as jnp
import numpy as np
import chex

import jaxatari.spaces as spaces
from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action
from jaxatari.renderers import JAXGameRenderer
from jaxatari.rendering import jax_rendering_utils as render_utils

def _get_default_asset_config() -> tuple:
    """
    Returns the default declarative asset manifest for Seaquest.
    Kept immutable (tuple of dicts) to fit NamedTuple defaults.
    """
    return (
        {'name': 'background', 'type': 'background', 'file': 'bg/1.npy'},
        {'name': 'player_sub', 'type': 'group', 'files': ['player_sub/1.npy', 'player_sub/2.npy', 'player_sub/3.npy']},
        {'name': 'diver', 'type': 'group', 'files': ['diver/1.npy', 'diver/2.npy']},
        {'name': 'shark_base', 'type': 'group', 'files': ['shark/1.npy', 'shark/2.npy']},
        {'name': 'enemy_sub', 'type': 'group', 'files': ['enemy_sub/1.npy', 'enemy_sub/2.npy', 'enemy_sub/3.npy']},
        {'name': 'player_torp', 'type': 'single', 'file': 'player_torp/1.npy'},
        {'name': 'enemy_torp', 'type': 'single', 'file': 'enemy_torp/1.npy'},
        {'name': 'life_indicator', 'type': 'single', 'file': 'life_indicator/1.npy'},
        {'name': 'diver_indicator', 'type': 'single', 'file': 'diver_indicator/1.npy'},
        {'name': 'digits', 'type': 'digits', 'pattern': 'digits/{}.npy'},
    )


[docs] class SeaquestConstants(NamedTuple): # Colors BACKGROUND_COLOR = (0, 0, 139) # Dark blue for water PLAYER_COLOR = (187, 187, 53) # Yellow for player sub DIVER_COLOR = (66, 72, 200) # Pink for divers SHARK_DIFFICULTY_COLORS = jnp.array( [ [92, 186, 92], # Level 0: Base green [213, 130, 74], # Level 1: Orange (adjusted from original ROM) [ 170, 92, 170, ], # Level 2: Purple (adjusted from original ROM COLOR_KILLER_SHARK_02) [213, 92, 130], # Level 3: Pink (adjusted from original ROM) [186, 92, 92], # Level 4: Red (adjusted from original ROM) ] ) ENEMY_SUB_COLOR = (170, 170, 170) # Gray for enemy subs OXYGEN_BAR_COLOR = (214, 214, 214, 255) # White for oxygen SCORE_COLOR = (210, 210, 64) # Score color OXYGEN_TEXT_COLOR = (0, 0, 0) # Black for oxygen text # Object sizes and initial positions from RAM state PLAYER_SIZE = (16, 11) # Width, Height DIVER_SIZE = (8, 11) SHARK_SIZE = (8, 7) ENEMY_SUB_SIZE = (8, 11) MISSILE_SIZE = (8, 1) PLAYER_START_X = 76 PLAYER_START_Y = 46 X_BORDERS = (0, 160) PLAYER_BOUNDS = (21, 134), (46, 141) # Maximum number of objects (from MAX_NB_OBJECTS) MAX_DIVERS = 4 MAX_SHARKS = 12 MAX_SUBS = 12 MAX_ENEMY_MISSILES = 4 MAX_PLAYER_TORPS = 1 MAX_SURFACE_SUBS = 1 MAX_COLLECTED_DIVERS = 6 # define object orientations FACE_LEFT = -1 FACE_RIGHT = 1 SPAWN_POSITIONS_Y = jnp.array([71, 95, 119, 139]) # submarines at y=69? SUBMARINE_Y_OFFSET = 2 ENEMY_MISSILE_Y = jnp.array([73, 97, 121, 141]) # missile x = submarine.x + 4 DIVER_SPAWN_POSITIONS = jnp.array([69, 93, 117, 141]) MISSILE_SPAWN_POSITIONS = jnp.array([39, 126]) # Right, Left # First wave directions from original code FIRST_WAVE_DIRS = jnp.array([False, False, False, True]) # Asset config baked into constants (immutable default) for asset overrides ASSET_CONFIG: tuple = _get_default_asset_config()
[docs] class SpawnState(NamedTuple): difficulty: chex.Array # Current difficulty level (0-7) lane_dependent_pattern: chex.Array # Track waves independently per lane [4 lanes] to_be_spawned: ( chex.Array ) # tracks which enemies are still in the spawning cycle [4 lanes * 3 slots] -> necessary due to the spaced out spawning of multiple enemies survived: ( chex.Array ) # track if last enemy survived [4 lanes * 3 slots] -> 1 if survived whilst going right, 0 if not, -1 if survived whilst going left prev_sub: chex.Array # Track previous entity type for each lane [4 lanes] spawn_timers: chex.Array # Individual spawn timers per lane [4 lanes] diver_array: ( chex.Array ) # Track which divers are still in the spawning cycle [4 lanes] lane_directions: ( chex.Array ) # Track lane directions for each wave [4 lanes] -> 0 = right, 1 = left
# Game state container
[docs] class SeaquestState(NamedTuple): player_x: chex.Array player_y: chex.Array player_direction: chex.Array # 0 for right, 1 for left oxygen: chex.Array divers_collected: chex.Array score: chex.Array lives: chex.Array spawn_state: SpawnState diver_positions: chex.Array # (4, 3) array for divers shark_positions: ( chex.Array ) # (12, 3) array for sharks - separated into 4 lanes, 3 slots per lane [left to right] sub_positions: ( chex.Array ) # (12, 3) array for enemy subs - separated into 4 lanes, 3 slots per lane [left to right] enemy_missile_positions: ( chex.Array ) # (4, 3) array for enemy missiles (only the front boats can shoot) surface_sub_position: chex.Array # (1, 3) array for surface submarine player_missile_position: ( chex.Array ) # (1, 3) array for player missile (x, y, direction) step_counter: chex.Array just_surfaced: chex.Array # Flag for tracking actual surfacing moment successful_rescues: ( chex.Array ) # Number of times the player has surfaced with all six divers death_counter: chex.Array # Counter for tracking death animation rng_key: chex.PRNGKey
[docs] class PlayerEntity(NamedTuple): x: jnp.ndarray y: jnp.ndarray o: jnp.ndarray width: jnp.ndarray height: jnp.ndarray active: jnp.ndarray
[docs] class EntityPosition(NamedTuple): x: jnp.ndarray y: jnp.ndarray width: jnp.ndarray height: jnp.ndarray active: jnp.ndarray
[docs] class SeaquestObservation(NamedTuple): player: PlayerEntity sharks: jnp.ndarray # Shape (12, 5) - 12 sharks, each with x,y,w,h,active submarines: jnp.ndarray # Shape (12, 5) divers: jnp.ndarray # Shape (4, 5) enemy_missiles: jnp.ndarray # Shape (4, 5) surface_submarine: EntityPosition player_missile: EntityPosition collected_divers: jnp.ndarray # Number of divers collected (0-6) player_score: jnp.ndarray lives: jnp.ndarray oxygen_level: jnp.ndarray # Oxygen level (0-255)
[docs] class SeaquestInfo(NamedTuple): difficulty: jnp.ndarray # Current difficulty level successful_rescues: jnp.ndarray # Number of successful rescues step_counter: jnp.ndarray # Current step count
[docs] class CarryState(NamedTuple): missile_pos: chex.Array shark_pos: chex.Array sub_pos: chex.Array score: chex.Array
# RENDER CONSTANTS
[docs] def get_shark_color_index(difficulty: chex.Array) -> chex.Array: """ Determine which shark color to use based on difficulty level. Color cycle: Green -> Yellow -> Pink -> Orange -> Green -> Yellow -> Green -> Orange -> back to start Args: difficulty: Current difficulty level (0-7) Returns: Color index: 0=Green, 1=Yellow, 2=Pink, 3=Orange """ # Map difficulty to color index using the specific 8-level pattern # Pattern: Green -> Yellow -> Pink -> Orange -> Green -> Yellow -> Green -> Orange -> back to start color_mapping = jnp.array([0, 1, 2, 3, 0, 1, 0, 3]) # 0=Green, 1=Yellow, 2=Pink, 3=Orange color_index = jnp.take(color_mapping, difficulty % 8) return color_index
[docs] class JaxSeaquest(JaxEnvironment[SeaquestState, SeaquestObservation, SeaquestInfo, SeaquestConstants]):
[docs] def initialize_spawn_state(self) -> SpawnState: """Initialize spawn state with first wave matching original game.""" return SpawnState( difficulty=jnp.array(0), lane_dependent_pattern=jnp.zeros( 4, dtype=jnp.int32 ), # Each lane starts at wave 0 to_be_spawned=jnp.zeros( 12, dtype=jnp.int32 ), # Track which enemies are still in the spawning cycle survived=jnp.zeros(12, dtype=jnp.int32), # Track which enemies survived prev_sub=jnp.zeros( 4, dtype=jnp.int32 ), # Track previous entity type (0 if shark, 1 if sub) -> starts at 1 since the first wave is sharks spawn_timers=jnp.array( [277, 277, 277, 277 + 60], dtype=jnp.int32 ), # All lanes start with same timer diver_array=jnp.array([1, 1, 0, 0], dtype=jnp.int32), lane_directions=self.consts.FIRST_WAVE_DIRS.astype(jnp.int32), # First wave directions )
[docs] def soft_reset_spawn_state(self, spawn_state: SpawnState) -> SpawnState: """Reset spawn_times""" return spawn_state._replace( spawn_timers=jnp.array([277, 277, 277, 277], dtype=jnp.int32) )
[docs] @partial(jax.jit, static_argnums=(0,)) def check_collision_single(self, pos1, size1, pos2, size2): """Check collision between two single entities""" # Calculate edges for rectangle 1 rect1_left = pos1[0] rect1_right = pos1[0] + size1[0] rect1_top = pos1[1] rect1_bottom = pos1[1] + size1[1] # Calculate edges for rectangle 2 rect2_left = pos2[0] rect2_right = pos2[0] + size2[0] rect2_top = pos2[1] rect2_bottom = pos2[1] + size2[1] # Check overlap horizontal_overlap = jnp.logical_and( rect1_left < rect2_right, rect1_right > rect2_left ) vertical_overlap = jnp.logical_and( rect1_top < rect2_bottom, rect1_bottom > rect2_top ) return jnp.logical_and(horizontal_overlap, vertical_overlap)
[docs] @partial(jax.jit, static_argnums=(0,)) def check_collision_batch(self, pos1, size1, pos2_array, size2): """Check collision between one entity and an array of entities""" # Calculate edges for rectangle 1 rect1_left = pos1[0] rect1_right = pos1[0] + size1[0] rect1_top = pos1[1] rect1_bottom = pos1[1] + size1[1] # Calculate edges for all rectangles in pos2_array rect2_left = pos2_array[:, 0] rect2_right = pos2_array[:, 0] + size2[0] rect2_top = pos2_array[:, 1] rect2_bottom = pos2_array[:, 1] + size2[1] # Check overlap for all entities horizontal_overlaps = jnp.logical_and( rect1_left < rect2_right, rect1_right > rect2_left ) vertical_overlaps = jnp.logical_and( rect1_top < rect2_bottom, rect1_bottom > rect2_top ) # Combine checks for each entity collisions = jnp.logical_and(horizontal_overlaps, vertical_overlaps) # Return true if any collision detected return jnp.any(collisions)
[docs] @partial(jax.jit, static_argnums=(0,)) def check_missile_collisions( self, missile_pos: chex.Array, shark_positions: chex.Array, sub_positions: chex.Array, score: chex.Array, successful_rescues: chex.Array, spawn_state: SpawnState, rng_key: chex.PRNGKey, ) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array, SpawnState, chex.PRNGKey]: """ Check for collisions between player missile and enemies using a vectorized approach. """ missile_rect_pos = missile_pos[:2] missile_active = missile_pos[2] != 0 # --- 1. Vectorized Collision Detection --- all_enemies = jnp.concatenate([shark_positions, sub_positions], axis=0) enemy_sizes = jnp.concatenate([ jnp.repeat(jnp.array(self.consts.SHARK_SIZE)[None, :], shark_positions.shape[0], axis=0), jnp.repeat(jnp.array(self.consts.ENEMY_SUB_SIZE)[None, :], sub_positions.shape[0], axis=0) ], axis=0) def check_single_enemy(enemy_pos, enemy_size): return self.check_collision_single( missile_rect_pos, self.consts.MISSILE_SIZE, enemy_pos[:2], enemy_size ) all_collision_mask = jax.vmap(check_single_enemy, in_axes=(0, 0))(all_enemies, enemy_sizes) all_collision_mask = jnp.logical_and(missile_active, all_collision_mask) shark_collision_mask = all_collision_mask[:shark_positions.shape[0]] sub_collision_mask = all_collision_mask[shark_positions.shape[0]:] # --- 2. Update Game State Based on Collision Masks --- points_per_kill = self.calculate_kill_points(successful_rescues) score_increase = jnp.sum(all_collision_mask * points_per_kill) new_score = score + score_increase zeros = jnp.zeros_like(shark_positions[0]) new_shark_positions = jnp.where(shark_collision_mask[:, None], zeros, shark_positions) new_sub_positions = jnp.where(sub_collision_mask[:, None], zeros, sub_positions) missile_was_destroyed = jnp.any(all_collision_mask) new_missile_pos = jnp.where(missile_was_destroyed, jnp.zeros(3), missile_pos) # --- 3. Update SpawnState Based on Collision Masks --- # The survived array is (12,), so we need to merge collision results from both # shark and sub slots. is_sub_mask_slots = jnp.repeat(spawn_state.prev_sub.astype(bool), 3) # (4,) -> (12,) final_collision_mask = jnp.where(is_sub_mask_slots, sub_collision_mask, shark_collision_mask) new_survived = jnp.where(final_collision_mask, 0, spawn_state.survived) # Determine which *lanes* had a collision across all 8 virtual lanes. lane_had_collision_8 = jnp.any(all_collision_mask.reshape(8, 3), axis=1) # Shape (8,) # Merge the 8-lane collision results into a 4-lane mask shark_lanes_hit, sub_lanes_hit = lane_had_collision_8[:4], lane_had_collision_8[4:] lane_had_collision_4 = jnp.where(spawn_state.prev_sub.astype(bool), sub_lanes_hit, shark_lanes_hit) # Shape (4,) # Update Spawn Timers for the 4 physical lanes new_spawn_timers = jnp.where(lane_had_collision_4, 200, spawn_state.spawn_timers) # Update Lane Directions for the 4 physical lanes rng_key, dir_rng_key = jax.random.split(rng_key) random_directions = jax.random.bernoulli(dir_rng_key, 0.5, (4,)).astype(jnp.int32) new_lane_directions = jnp.where( lane_had_collision_4, random_directions, spawn_state.lane_directions ) new_spawn_state = spawn_state._replace( survived=new_survived, spawn_timers=new_spawn_timers, lane_directions=new_lane_directions, ) return ( new_missile_pos, new_shark_positions, new_sub_positions, new_score, new_spawn_state, rng_key, )
[docs] @partial(jax.jit, static_argnums=(0,)) def check_player_collision( self, player_x, player_y, submarine_list, shark_list, surface_sub_pos, enemy_projectile_list, score, successful_rescues, ) -> Tuple[chex.Array, chex.Array]: # check if the player has collided with any of the three given lists # the player is a 16x11 rectangle # the submarine is a 8x11 rectangle # the shark is a 8x7 rectangle # the missile is a 8x1 rectangle # the surface submarine is 8x11 as well # check if the player has collided with any of the submarines submarine_collisions = jnp.any( self.check_collision_batch( jnp.array([player_x, player_y]), self.consts.PLAYER_SIZE, submarine_list, self.consts.ENEMY_SUB_SIZE ) ) # check if the player has collided with any of the sharks shark_collisions = jnp.any( self.check_collision_batch( jnp.array([player_x, player_y]), self.consts.PLAYER_SIZE, shark_list, self.consts.SHARK_SIZE ) ) # check if the player collided with the surface submarine surface_collision = self.check_collision_single( jnp.array([player_x, player_y]), self.consts.PLAYER_SIZE, surface_sub_pos, self.consts.ENEMY_SUB_SIZE ) # check if the player has collided with any of the enemy projectiles missile_collisions = jnp.any( self.check_collision_batch( jnp.array([player_x, player_y]), self.consts.PLAYER_SIZE, enemy_projectile_list, self.consts.MISSILE_SIZE ) ) # Calculate points for collisions. # When colliding with a shark or submarine the player gains points similar to killing the object collision_points = jnp.where( shark_collisions, self.calculate_kill_points(successful_rescues), jnp.where( submarine_collisions, self.calculate_kill_points(successful_rescues), jnp.where(surface_collision, self.calculate_kill_points(successful_rescues), 0), ), ) return ( jnp.any( jnp.array( [ submarine_collisions, shark_collisions, missile_collisions, surface_collision, ] ) ), collision_points, )
[docs] @partial(jax.jit, static_argnums=(0,)) def get_spawn_position(self, moving_left: chex.Array, slot: chex.Array) -> chex.Array: """Get spawn position based on movement direction and slot number""" base_y = jnp.array(self.consts.SPAWN_POSITIONS_Y[slot]) x_pos = jnp.where( moving_left, jnp.array(165, dtype=jnp.int32), # Start right if moving left jnp.array(0, dtype=jnp.int32), ) # Start left if moving right direction = jnp.where(moving_left, -1, 1) # -1 for left, 1 for right return jnp.array([x_pos, base_y, direction], dtype=jnp.int32)
[docs] @partial(jax.jit, static_argnums=(0,)) def is_slot_empty(self, pos: chex.Array) -> chex.Array: """Check if a position slot is empty (0,0,ß)""" return pos[2] == 0
[docs] @partial(jax.jit, static_argnums=(0,)) def get_front_entity(self, i, lane_positions): # check on the first submarine in the lane which direction they are going direction = lane_positions[0][2] direction = jnp.where( lane_positions[0][2] == 0, jnp.where( lane_positions[1][2] == 0, lane_positions[2][2], lane_positions[1][2] ), lane_positions[0][2], ) # if direction is 1, go from right to left until an active entity is found # if direction is -1, go from left to right until an active entity is found front_entity = jnp.where( direction == -1, jnp.where( lane_positions[0][2] != 0, lane_positions[0], jnp.where( lane_positions[1][2] != 0, lane_positions[1], jnp.where(lane_positions[2][2] != 0, lane_positions[2], jnp.zeros(3)), ), ), jnp.where( lane_positions[2][2] != 0, lane_positions[2], jnp.where( lane_positions[1][2] != 0, lane_positions[1], jnp.where(lane_positions[0][2] != 0, lane_positions[0], jnp.zeros(3)), ), ), ) return front_entity
[docs] @partial(jax.jit, static_argnums=(0,)) def get_pattern_for_difficulty( self, current_pattern: chex.Array, moving_left: chex.Array ) -> chex.Array: """Returns spawn pattern based on the lane's current wave/pattern number Pattern meanings: 0: Single enemy (initial pattern) 1: Two adjacent enemies 2: Two enemies with gap 3: Three enemies in a row """ # Basic pattern arrays for different formations PATTERNS = jnp.array( [ [0, 0, 1], # wave 0: Single enemy [0, 1, 1], # wave 1: Two adjacent [1, 0, 1], # wave 2: Two with gap [1, 1, 1], # wave 3: Three in row ] ) # Reverse pattern if moving left base_pattern = PATTERNS[current_pattern] return base_pattern
[docs] @partial(jax.jit, static_argnums=(0,)) def update_enemy_spawns( self, spawn_state: SpawnState, shark_positions: chex.Array, sub_positions: chex.Array, diver_positions: chex.Array, step_counter: chex.Array, rng: chex.PRNGKey = None, ) -> Tuple[SpawnState, chex.Array, chex.Array, chex.PRNGKey]: """Update enemy spawns using pattern-based system matching original game. Args: spawn_state: Current spawn state shark_positions: Current shark positions sub_positions: Current submarine positions diver_positions: Current diver positions step_counter: Current step counter rng: Optional random key for direction randomization Returns: Tuple of updated spawn state, shark positions, sub positions, and updated RNG key """ new_spawn_timers = jnp.where( spawn_state.spawn_timers > 0, spawn_state.spawn_timers - 1, spawn_state.spawn_timers, ) new_state = spawn_state._replace(spawn_timers=new_spawn_timers) # --- START of new vectorized calculation --- # 1. Vectorized check for empty lanes across all 4 lanes sharks_active = shark_positions.reshape(4, 3, 3)[:, :, 2] != 0 subs_active = sub_positions.reshape(4, 3, 3)[:, :, 2] != 0 all_lanes_empty = jnp.all(~sharks_active & ~subs_active, axis=1) # Shape (4,) # 2. Vectorized check for entities that still need to be spawned to_be_spawned_lanes = spawn_state.to_be_spawned.reshape(4, 3) any_to_be_spawned = jnp.any(to_be_spawned_lanes != 0, axis=1) # Shape (4,) # 3. Combine conditions to create a mask of all lanes that need an update all_lanes_need_update = jnp.logical_or(all_lanes_empty, any_to_be_spawned) # Shape (4,) # --- END of new vectorized calculation --- # The scan_lanes function is now much simpler def scan_lanes(carry, lane_idx): curr_state, curr_shark_positions, curr_sub_positions, curr_diver_positions, curr_rng = carry # Use the pre-computed mask to check if this lane needs an update needs_update = all_lanes_need_update[lane_idx] # The rest of the function proceeds as before new_carry = jax.lax.cond( needs_update, lambda x: process_lane(lane_idx, x), # process_lane is unchanged lambda x: x, (curr_state, curr_shark_positions, curr_sub_positions, curr_diver_positions, curr_rng), ) return new_carry, None def initialize_new_spawn_cycle(i, carry): spawn_state, shark_positions, sub_positions, diver_positions, rng = carry # Split RNG key for this lane rng, lane_rng = jax.random.split(rng) # Get survived status for this lane (3 slots) lane_survived = jax.lax.dynamic_slice(spawn_state.survived, (i * 3,), (3,)) # Update the difficulty patterns for this lane left_over = jnp.any(lane_survived) clipped_difficulty = spawn_state.difficulty % 8 # Update spawn state lane_specific_pattern = jnp.where( jnp.logical_not(left_over), # Only update if all destroyed jnp.where( clipped_difficulty < 2, 0, jnp.where( clipped_difficulty < 4, 1, jnp.where( clipped_difficulty < 6, 2, jnp.where(clipped_difficulty < 8, 3, 0), ), ), ), spawn_state.lane_dependent_pattern[i], ) # Check if there's an active diver in this lane active_diver = diver_positions[i][2] != 0 diver_direction = diver_positions[i][2] # If there's an active diver, use its direction, otherwise randomize moving_left = jnp.where( active_diver, diver_direction == -1, # Use diver's direction if active spawn_state.lane_directions[i] == 1 # Otherwise use current lane direction ) # get the spawn pattern for this lane # Check if this slot had something survive last time (if yes, we have to overwrite the current_pattern) current_pattern = jnp.where( left_over, lane_survived, self.get_pattern_for_difficulty(lane_specific_pattern, moving_left), ) # make sure that in the current pattern all entries are positive (i.e. abs() on all values) current_pattern = jnp.abs(current_pattern) # in case we are going left, flip the pattern current_pattern = jnp.where( moving_left, -jnp.flip(current_pattern), current_pattern ) # check if this should be a submarine or a shark is_sub = jnp.logical_and(left_over, jnp.logical_not(spawn_state.prev_sub[i])) # set the positions for the first enemy in the wave (dependent on the direction this is either the first or the last slot) first_slot = jnp.where(moving_left, 0, 2) base_pos = self.get_spawn_position(moving_left, jnp.array(i)) # spawn the first enemy in the wave new_shark_positions = jnp.where( is_sub, shark_positions, shark_positions.at[(i * 3 + first_slot)].set(base_pos), ) new_sub_positions = jnp.where( is_sub, sub_positions.at[(i * 3 + first_slot)].set(base_pos), sub_positions ) # wipe the survived status for this lane (since we are starting a new wave) indices = jnp.array([i * 3, i * 3 + 1, i * 3 + 2]) new_survived_full = spawn_state.survived.at[indices].set( jnp.zeros(3, dtype=jnp.int32) ) # Set moving_left to the opposite of moving_left when determining which slot to clear in to_be_spawned new_to_be_spawned = current_pattern.at[jnp.where(moving_left, 0, 2)].set(0) # Update the full to_be_spawned array for this lane new_full_to_be_spawned = spawn_state.to_be_spawned.at[indices].set( new_to_be_spawned ) new_spawn_state = SpawnState( difficulty=spawn_state.difficulty, lane_dependent_pattern=spawn_state.lane_dependent_pattern.at[i].set( lane_specific_pattern ), to_be_spawned=new_full_to_be_spawned, survived=new_survived_full, prev_sub=spawn_state.prev_sub.at[i].set(is_sub), spawn_timers=spawn_state.spawn_timers.at[i].set(200), diver_array=spawn_state.diver_array, lane_directions=spawn_state.lane_directions, ) return new_spawn_state, new_shark_positions, new_sub_positions, diver_positions, rng # Modified continue_spawn_cycle to handle RNG def continue_spawn_cycle(i: int, carry): spawn_state, shark_positions, sub_positions, diver_positions, rng = carry # Rest of function remains the same, just pass along the RNG # get the relevant missing entities for this lane from the to_be_spawned array relevant_to_be_spawned = jax.lax.dynamic_slice( spawn_state.to_be_spawned, (i * 3,), (3,) ) # check in which direction we are moving by finding the first non-zero value in the missing_entities array moving_left = jnp.where( relevant_to_be_spawned[0] == 0, jnp.where( relevant_to_be_spawned[1] == 0, jnp.where(relevant_to_be_spawned[2] == -1, True, False), jnp.where(relevant_to_be_spawned[1] == -1, True, False), ), jnp.where(relevant_to_be_spawned[0] == -1, True, False), ) # Find the index of the first non-zero value based on direction def scan_right_to_left(j, val): return jnp.where(relevant_to_be_spawned[2 - j] != 0, 2 - j, val) def scan_left_to_right(j, val): return jnp.where(relevant_to_be_spawned[j] != 0, j, val) # Use fori_loop to scan array in appropriate direction spawn_idx = jax.lax.cond( moving_left, lambda _: jax.lax.fori_loop(0, 3, scan_left_to_right, -1), lambda _: jax.lax.fori_loop(0, 3, scan_right_to_left, -1), operand=None, ) spawn_idx = spawn_idx.astype(jnp.int32) # Get reference x position from neighboring entity # For moving right, look at entity to the right (spawn_idx + 1) # For moving left, look at entity to the left (spawn_idx - 1) reference_idx = jnp.where(moving_left, spawn_idx - 1, spawn_idx + 1) reference_idx = reference_idx.astype(jnp.int32) base_idx = i * 3 # Base index for this lane's entities # Get position from either shark or sub position arrays # We'll need to check both since we don't know which type exists reference_shark_pos = shark_positions[base_idx + reference_idx] reference_sub_pos = sub_positions[base_idx + reference_idx] # Use whichever position is non-zero (active) reference_x = jnp.where( reference_shark_pos[0] != 0, reference_shark_pos[0], reference_sub_pos[0] ) edge_case = reference_x == 0 # Edge Case: third option exists for the pattern 1 0 1, then check the next entity edge_case_reference_idx = jnp.where(moving_left, spawn_idx - 2, spawn_idx + 2) edge_case_reference_idx = edge_case_reference_idx.astype(jnp.int32) reference_x = jnp.where( edge_case, jnp.where( shark_positions[base_idx + edge_case_reference_idx][0] != 0, shark_positions[base_idx + edge_case_reference_idx][0], sub_positions[base_idx + edge_case_reference_idx][0], ), reference_x, ) # Get base spawn position for this lane base_spawn_pos = self.get_spawn_position(moving_left, jnp.array(i)) # check if the base spawn position x is 16 / 32 pixels away from the reference x position (depending on the edge case pattern) # if yes, spawn the entity, if no, do nothing offset = jnp.where(edge_case, 32, 16) should_spawn = jnp.abs(base_spawn_pos[0] - reference_x) >= offset # in case reference_x is still 0 (happens in case the player destroyed the first entity in the wave), we just instantly spawn the entity should_spawn = jnp.where(reference_x == 0, True, should_spawn) spawn_pos = jnp.where(should_spawn, base_spawn_pos, jnp.zeros(3)) # Update positions based on enemy type new_shark_positions = shark_positions.at[base_idx + spawn_idx].set( jnp.where( jnp.logical_not(spawn_state.prev_sub[i]), spawn_pos, shark_positions[base_idx + spawn_idx], ) ) new_sub_positions = sub_positions.at[base_idx + spawn_idx].set( jnp.where( spawn_state.prev_sub[i], spawn_pos, sub_positions[base_idx + spawn_idx] ) ) # Update the to_be_spawned array new_to_be_spawned = spawn_state.to_be_spawned.at[base_idx + spawn_idx].set( jnp.where( should_spawn, jnp.array(0), # Single value spawn_state.to_be_spawned[base_idx + spawn_idx], ) ) # Then create the new spawn state with the updated array new_spawn_state = SpawnState( difficulty=spawn_state.difficulty, lane_dependent_pattern=spawn_state.lane_dependent_pattern, to_be_spawned=new_to_be_spawned, survived=spawn_state.survived, prev_sub=spawn_state.prev_sub, spawn_timers=spawn_state.spawn_timers, diver_array=spawn_state.diver_array, lane_directions=spawn_state.lane_directions, ) return new_spawn_state, new_shark_positions, new_sub_positions, diver_positions, rng # Modified process_lane to handle RNG def process_lane(i, carry): loc_spawn_state, shark_positions, sub_positions, diver_positions, rng = carry base_idx = i * 3 # Base index for this lane's slots # determine if we need to initialize a new pattern or keep spawning for the current one # do this by checking in the relevant part of the to_be_spawned array if there are still 1s relevant_to_be_spawned = jax.lax.dynamic_slice( spawn_state.to_be_spawned, (base_idx,), (3,) ) # if there are still 1s in the relevant part of the to_be_spawned array, keep spawning keep_spawning = jnp.any(relevant_to_be_spawned) # check the lane spawn timer lane_timer = spawn_state.spawn_timers[i] base_idx = i * 3 # Get the sharks and subs for the current lane `i` lane_sharks = jax.lax.dynamic_slice(shark_positions, (base_idx, 0), (3, 3)) lane_subs = jax.lax.dynamic_slice(sub_positions, (base_idx, 0), (3, 3)) # Vectorized check for active entities in the lane sharks_active = lane_sharks[:, 2] != 0 subs_active = lane_subs[:, 2] != 0 lane_empty = jnp.all(~sharks_active & ~subs_active) # if the lane timer is unequal to 0, continue_spawn_cycle may still be called but initialize_new_spawn_cycle should not be called allow_new_initialization = jnp.logical_and(lane_timer == 0, lane_empty) def handle_no_spawning(x): spawn_state, shark_positions, sub_positions, diver_positions, rng = x return jax.lax.cond( allow_new_initialization, lambda y: initialize_new_spawn_cycle(i, y), lambda y: (y[0], y[1], y[2], y[3], y[4]), # Return unchanged state (spawn_state, shark_positions, sub_positions, diver_positions, rng), ) new_spawn_state, new_shark_positions, new_sub_positions, new_diver_positions, new_rng = jax.lax.cond( keep_spawning, lambda x: continue_spawn_cycle(i, x), handle_no_spawning, (loc_spawn_state, shark_positions, sub_positions, diver_positions, rng), ) return new_spawn_state, new_shark_positions, new_sub_positions, new_diver_positions, new_rng # Modify lane_needs_update to work with the rest of the function def lane_needs_update(i, spawn_state, shark_positions, sub_positions): base_idx = i * 3 # Base index for this lane's slots # get how many entities in this lane are inactive lane_empty = jnp.all( jnp.array( [ jnp.logical_and( self.is_slot_empty(shark_positions[base_idx + j]), self.is_slot_empty(sub_positions[base_idx + j]), ) for j in range(3) ] ) ) # check if the to_be_spawned array has any 1s in the relevant part relevant_to_be_spawned = jax.lax.dynamic_slice( spawn_state.to_be_spawned, (base_idx,), (3,) ) return jnp.logical_or(lane_empty, jnp.any(relevant_to_be_spawned)) # Replace the manual loop with lax.scan lane_indices = jnp.arange(4) (final_state, final_shark_positions, final_sub_positions, final_diver_positions, final_rng), _ = jax.lax.scan( scan_lanes, (new_state, shark_positions, sub_positions, diver_positions, rng if rng is not None else jax.random.PRNGKey(42)), lane_indices ) return final_state, final_shark_positions, final_sub_positions, final_rng
[docs] @partial(jax.jit, static_argnums=(0,)) def step_enemy_movement( self, spawn_state: SpawnState, shark_positions: chex.Array, sub_positions: chex.Array, step_counter: chex.Array, rng: chex.PRNGKey, ) -> Tuple[chex.Array, chex.Array, SpawnState, chex.PRNGKey]: """Update enemy positions based on their patterns""" # Split RNG key for direction randomization rng, direction_rng = jax.random.split(rng) def get_shark_offset(step_counter): """Calculates the vertical sinusoidal-like offset for sharks.""" phase = step_counter // 4 cycle_position = phase % 32 raw_offset = jnp.where( cycle_position < 16, cycle_position // 2, 7 - (cycle_position - 16) // 2, ) return raw_offset - 4 def calculate_movement_speed(step_counter, difficulty): """ Calculates movement speed based on difficulty. This function is vectorized and uses jnp.select for efficient conditional logic. """ # Ensure difficulty is non-negative and wraps at 256 for consistent logic safe_difficulty = jnp.maximum(0, difficulty % 256) # --- Speed for difficulties 0-9 --- diff_lt_10 = safe_difficulty < 10 cycle_pos = step_counter % 12 # Movement probabilities for difficulties 0-9 should_move_patterns = jnp.array([ (cycle_pos % 3) == 0, # 33% (cycle_pos % 2) == 0, # 50% (cycle_pos % 3) != 2, # 67% (cycle_pos % 4) != 3, # 75% (cycle_pos % 6) != 5, # 83% cycle_pos != 11, # 92% ]) # Indices to select the correct pattern based on difficulty indices = jnp.array([0, 1, 1, 2, 2, 3, 3, 4, 4, 5]) should_move = should_move_patterns[indices[safe_difficulty]] speed_for_diff_0_9 = jnp.where(should_move, 1, 0) # --- Speed for difficulties 10+ --- diff_above_threshold = jnp.maximum(0, safe_difficulty - 10) base_speed = 1 + (diff_above_threshold // 16) position_in_tier = diff_above_threshold % 16 # Probabilities for gaining +1 speed within a tier higher_speed_patterns = jnp.array([ (step_counter % 16) == 0, # 6.25% (step_counter % 8) == 0, # 12.5% (step_counter % 4) == 0, # 25% (step_counter % 2) == 0, # 50% (step_counter % 4) != 0, # 75% (step_counter % 8) != 0, # 87.5% (step_counter % 16) != 0, # 93.75% ]) # Indices to select the correct probability pattern tier_indices = jnp.array([0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6]) use_higher_speed = higher_speed_patterns[tier_indices[position_in_tier]] speed_for_diff_10_plus = jnp.where(use_higher_speed, base_speed + 1, base_speed) # Select speed based on difficulty bracket return jnp.where(diff_lt_10, speed_for_diff_0_9, speed_for_diff_10_plus) def move_single_enemy(pos, is_shark, difficulty, slot_idx, step_counter): """Moves a single enemy. This function will be vmapped.""" is_active = jnp.logical_not(self.is_slot_empty(pos)) movement_speed = calculate_movement_speed(step_counter, difficulty) velocity_x = pos[2] * movement_speed # pos[2] is direction (-1 or 1) # Use modulo 4 to map 8 virtual lanes to 4 physical Y-positions lane_idx = (slot_idx // 3) % 4 base_y = self.consts.SPAWN_POSITIONS_Y[lane_idx] y_offset = jnp.where( is_shark, get_shark_offset(step_counter), -self.consts.SUBMARINE_Y_OFFSET ) y_position = base_y + y_offset # Ensure all calculations are done with integer arithmetic new_x = pos[0] + velocity_x new_pos = jnp.array([new_x, y_position, pos[2]], dtype=pos.dtype) new_pos = jnp.where(is_active, new_pos, pos) out_of_bounds = jnp.logical_or(new_pos[0] <= -8, new_pos[0] >= 168) final_pos = jnp.where(out_of_bounds, jnp.zeros_like(pos), new_pos) return final_pos, out_of_bounds # 1. Combine sharks and subs into a single array for vectorized processing # Ensure both arrays have the same dtype before concatenation shark_positions_int = shark_positions.astype(jnp.int32) sub_positions_int = sub_positions.astype(jnp.int32) all_positions = jnp.concatenate([shark_positions_int, sub_positions_int], axis=0) is_shark_array = jnp.concatenate([jnp.ones(12, dtype=bool), jnp.zeros(12, dtype=bool)]) all_slot_indices = jnp.arange(24) # 2. Apply movement to all 24 enemies in parallel using vmap vmap_move = jax.vmap(move_single_enemy, in_axes=(0, 0, None, 0, None)) new_all_positions, enemies_survived_mask = vmap_move( all_positions, is_shark_array, spawn_state.difficulty, all_slot_indices, step_counter ) # 3. Handle lane-based logic by padding the 4-lane state to an 8-lane structure # and then merging the results back down. # Pad the original 4-lane state for comparison purposes old_survived_padded = jnp.pad(spawn_state.survived, (0, 12)) lane_directions_padded = jnp.pad(spawn_state.lane_directions, (0, 4)) # Perform logic in the temporary 8-lane structure num_lanes = 8 # 4 for sharks, 4 for subs survived_mask_lanes = enemies_survived_mask.reshape(num_lanes, 3) old_survived_lanes = old_survived_padded.reshape(num_lanes, 3) any_newly_survived_in_lane = jnp.any( jnp.logical_and(survived_mask_lanes, old_survived_lanes == 0), axis=1 ) random_directions = jax.random.bernoulli(direction_rng, 0.5, (num_lanes,)).astype(spawn_state.survived.dtype) * 2 - 1 temp_lane_directions = jnp.where( any_newly_survived_in_lane, random_directions, lane_directions_padded ) all_pos_lanes = all_positions.reshape(num_lanes, 3, 3) dir0 = all_pos_lanes[:, 0, 2] dir1 = all_pos_lanes[:, 1, 2] dir2 = all_pos_lanes[:, 2, 2] lane_base_direction = jnp.where(dir0 != 0, dir0, jnp.where(dir1 != 0, dir1, jnp.where(dir2 != 0, dir2, 1))) survived_direction_per_slot = jnp.repeat(lane_base_direction, 3, axis=0) temp_survived = jnp.where( enemies_survived_mask, survived_direction_per_slot, old_survived_padded ) temp_survived_lanes = temp_survived.reshape(num_lanes, 3) lanes_to_flip = (lane_base_direction == -1) flipped_survived = jnp.flip(temp_survived_lanes, axis=1) temp_survived_lanes = jnp.where( lanes_to_flip[:, None], # Expand dims for broadcasting flipped_survived, temp_survived_lanes ) # 4. Merge the 8-lane results back into the 4-lane state structure # `spawn_state.prev_sub` (shape 4,) tells us if a lane is sharks (0) or subs (1) is_sub_mask = spawn_state.prev_sub[:, None].astype(bool) # Shape: (4, 1) # Merge `survived` results shark_survived_res = temp_survived_lanes[:4] sub_survived_res = temp_survived_lanes[4:] final_survived_lanes = jnp.where(is_sub_mask, sub_survived_res, shark_survived_res) new_survived = final_survived_lanes.flatten() # Final shape: (12,) # Merge `lane_directions` shark_dir_res, sub_dir_res = temp_lane_directions[:4], temp_lane_directions[4:] new_lane_directions = jnp.where(spawn_state.prev_sub.astype(bool), sub_dir_res, shark_dir_res) # Final shape: (4,) # Merge updates for `diver_array` and `spawn_timers` any_newly_survived_sharks = any_newly_survived_in_lane[:4] any_newly_survived_subs = any_newly_survived_in_lane[4:] any_newly_survived_final = jnp.where(spawn_state.prev_sub.astype(bool), any_newly_survived_subs, any_newly_survived_sharks) new_diver_array = jnp.where( jnp.logical_and(any_newly_survived_final, spawn_state.diver_array == -1), 1, spawn_state.diver_array ) new_spawn_timers = jnp.where( any_newly_survived_final, 200, spawn_state.spawn_timers ) # 5. Update state with merged results and split positions back into sharks and subs new_spawn_state = spawn_state._replace( survived=new_survived, lane_directions=new_lane_directions, diver_array=new_diver_array, spawn_timers=new_spawn_timers ) new_shark_positions, new_sub_positions = jnp.split(new_all_positions, 2, axis=0) # Ensure the returned positions have the same dtype as the input positions new_shark_positions = new_shark_positions.astype(shark_positions.dtype) new_sub_positions = new_sub_positions.astype(sub_positions.dtype) return new_shark_positions, new_sub_positions, new_spawn_state, rng
[docs] @partial(jax.jit, static_argnums=(0,)) def spawn_divers( self, spawn_state: SpawnState, diver_positions: chex.Array, shark_positions: chex.Array, sub_positions: chex.Array, step_counter: chex.Array, ) -> tuple[chex.Array, SpawnState]: """ Vectorized function to spawn divers according to pattern that depends on collection state. """ # --- 1. Vectorized Pre-computation and Checks (for all 4 lanes at once) --- # Condition: Only process lanes where the spawn timer is at the trigger value. timers_ready_mask = spawn_state.spawn_timers == 60 # Shape: (4,) # Condition: A diver must not already exist in the lane. diver_exists_mask = diver_positions[:, 2] != 0 # Shape: (4,) # Condition: The enemy lane must be empty. # This replaces the slow list comprehension with vectorized operations. sharks_active_per_lane = jnp.any(shark_positions.reshape(4, 3, 3)[:, :, 2] != 0, axis=1) subs_active_per_lane = jnp.any(sub_positions.reshape(4, 3, 3)[:, :, 2] != 0, axis=1) lanes_are_empty_mask = jnp.logical_not( jnp.logical_or(sharks_active_per_lane, subs_active_per_lane) ) # Shape: (4,) # Condition: The lane must be marked as available for spawning (value of 1). lanes_ready_to_spawn_mask = spawn_state.diver_array == 1 # Shape: (4,) # Condition: Do not spawn a diver if the next enemy in that lane is a submarine. prev_was_not_sub = jnp.logical_not(spawn_state.prev_sub) something_survived = jnp.any(spawn_state.survived.reshape(4, 3) != 0, axis=1) next_is_sub_mask = jnp.logical_and(prev_was_not_sub, something_survived) # Override: if the previous enemy was a sub, the next cannot be a sub. next_is_sub_mask = jnp.where(spawn_state.prev_sub, False, next_is_sub_mask) # --- 2. Combine All Conditions to Get Final Spawn Mask --- # A diver should spawn only if ALL conditions for its lane are met. should_spawn_mask = jnp.logical_and.reduce( jnp.array([ timers_ready_mask, jnp.logical_not(diver_exists_mask), lanes_are_empty_mask, lanes_ready_to_spawn_mask, jnp.logical_not(next_is_sub_mask), ]) ) # Shape: (4,) # --- 3. Calculate New Positions and State --- # Calculate spawn positions for all lanes based on their direction. moving_left_mask = spawn_state.lane_directions == 1 x_positions = jnp.where(moving_left_mask, 168, 0) directions = jnp.where(moving_left_mask, -1, 1) # Create the potential new diver data for all 4 lanes. potential_new_divers = jnp.stack( [x_positions, self.consts.DIVER_SPAWN_POSITIONS, directions], axis=1 ) # Shape: (4, 3) # Use the final spawn mask to decide whether to use the new diver data or keep the old. # The `[:, None]` broadcasts the (4,) mask to the (4, 3) diver_positions array. new_diver_positions = jnp.where( should_spawn_mask[:, None], potential_new_divers, diver_positions ) # Update the diver_array: If a lane was marked with -1 (diver swam off-screen) # and the lane is now empty, mark it as ready for a new spawn cycle (value 1). spawn_next_cycle_mask = jnp.logical_and(spawn_state.diver_array == -1, lanes_are_empty_mask) new_diver_array = jnp.where( spawn_next_cycle_mask, 1, spawn_state.diver_array ) return new_diver_positions, spawn_state._replace(diver_array=new_diver_array)
[docs] @partial(jax.jit, static_argnums=(0,)) def step_diver_movement( self, diver_positions: chex.Array, shark_positions: chex.Array, state_player_x: chex.Array, state_player_y: chex.Array, state_divers_collected: chex.Array, spawn_state: SpawnState, step_counter: chex.Array, rng: chex.PRNGKey, ) -> tuple[chex.Array, chex.Array, SpawnState, chex.PRNGKey]: """Move divers according to their pattern and handle collisions. Returns updated diver positions, number of collected divers, updated spawn state, and updated RNG key. """ new_diver_array = spawn_state.diver_array def calculate_diver_movement(step_counter, difficulty): """Calculate diver movement based on difficulty level. Args: step_counter: Current step counter (frame number) difficulty: Current difficulty level (0-255) Returns: Movement speed for the current frame (0, 1, or 2+) 0 = no movement, 1 = normal speed, 2+ = higher speeds """ # Ensure difficulty is non-negative and handle wrapping safe_difficulty = jnp.clip(difficulty % 256, 0, 255) # For difficulties 0-27, we have specific movement patterns is_high_difficulty = safe_difficulty >= 28 # For difficulties 0-27, determine if we should move and use speed 1 low_diff_should_move = determine_low_difficulty_movement( step_counter, safe_difficulty ) low_diff_speed = jnp.where(low_diff_should_move, 1, 0) # For difficulties 28+, always move but with varying speed high_diff_speed = determine_high_difficulty_speed(step_counter, safe_difficulty) # Return appropriate speed based on difficulty return jnp.where(is_high_difficulty, high_diff_speed, low_diff_speed) def determine_low_difficulty_movement(step_counter, difficulty): """Determine if the diver should move for difficulties 0-27.""" # Create boolean masks for each difficulty bracket diff_0_1 = jnp.logical_and(difficulty >= 0, difficulty <= 1) diff_2_3 = jnp.logical_and(difficulty >= 2, difficulty <= 3) diff_4_5 = jnp.logical_and(difficulty >= 4, difficulty <= 5) diff_6_7 = jnp.logical_and(difficulty >= 6, difficulty <= 7) diff_8_9 = jnp.logical_and(difficulty >= 8, difficulty <= 9) diff_10_11 = jnp.logical_and(difficulty >= 10, difficulty <= 11) diff_12_13 = jnp.logical_and(difficulty >= 12, difficulty <= 13) diff_14_15 = jnp.logical_and(difficulty >= 14, difficulty <= 15) diff_16_17 = jnp.logical_and(difficulty >= 16, difficulty <= 17) diff_18_19 = jnp.logical_and(difficulty >= 18, difficulty <= 19) diff_20_21 = jnp.logical_and(difficulty >= 20, difficulty <= 21) diff_22_23 = jnp.logical_and(difficulty >= 22, difficulty <= 23) diff_24_25 = jnp.logical_and(difficulty >= 24, difficulty <= 25) diff_26_27 = jnp.logical_and(difficulty >= 26, difficulty <= 27) # Movement patterns for each bracket based on paste.txt # Difficulty 0-1: Move every 5th frame (20% movement) move_0_1 = (step_counter % 5) == 0 # Difficulty 2-3: Move every 4th frame (25% movement) move_2_3 = (step_counter % 4) == 0 # Difficulty 4-5: Move every 3rd frame (33.3% movement) move_4_5 = (step_counter % 3) == 0 # Difficulty 6-7: Move in pattern [1,0,0,1,0,1,0,0] (37.5% movement) cycle_6_7 = step_counter % 8 move_6_7 = jnp.logical_or( cycle_6_7 == 0, jnp.logical_or(cycle_6_7 == 3, cycle_6_7 == 5) ) # Difficulty 8-9: Move in pattern [1,0,1,0,1,0,0,1,0,1] (50% movement) cycle_8_9 = step_counter % 10 move_8_9 = jnp.logical_or( jnp.logical_or(cycle_8_9 == 0, cycle_8_9 == 2), jnp.logical_or(cycle_8_9 == 4, cycle_8_9 == 7), ) move_8_9 = jnp.logical_or(move_8_9, cycle_8_9 == 9) # Difficulty 10-11: Move every other frame (50% movement) move_10_11 = (step_counter % 2) == 0 # Difficulty 12-13: Complex pattern with ~60% movement cycle_12_13 = step_counter % 8 move_12_13 = jnp.logical_or( jnp.logical_or(cycle_12_13 == 0, cycle_12_13 == 2), jnp.logical_or(cycle_12_13 == 4, cycle_12_13 == 6), ) move_12_13 = jnp.logical_or(move_12_13, cycle_12_13 == 7) # Difficulty 14-15: Complex pattern with ~65% movement cycle_14_15 = step_counter % 7 move_14_15 = jnp.logical_or( jnp.logical_or(cycle_14_15 == 0, cycle_14_15 == 1), jnp.logical_or(cycle_14_15 == 3, cycle_14_15 == 5), ) move_14_15 = jnp.logical_or(move_14_15, cycle_14_15 == 6) # Difficulty 16-17: Complex pattern with ~70% movement cycle_16_17 = step_counter % 10 move_16_17 = jnp.logical_or( jnp.logical_or(cycle_16_17 == 0, cycle_16_17 == 1), jnp.logical_or(cycle_16_17 == 3, cycle_16_17 == 4), ) move_16_17 = jnp.logical_or( move_16_17, jnp.logical_or(cycle_16_17 == 6, cycle_16_17 == 8) ) move_16_17 = jnp.logical_or(move_16_17, cycle_16_17 == 9) # Difficulty 18-19: Move 3 out of 4 frames (75% movement) move_18_19 = (step_counter % 4) != 3 # Difficulty 20-21: Move 4 out of 5 frames (80% movement) move_20_21 = (step_counter % 5) != 4 # Difficulty 22-23: Move 7 out of 8 frames (87.5% movement) move_22_23 = (step_counter % 8) != 7 # Difficulty 24-25: Move 15 out of 16 frames (93.75% movement) move_24_25 = (step_counter % 16) != 15 # Difficulty 26-27: Always move (100% movement) move_26_27 = True # Combine all patterns using jnp.select which is cleaner for many conditions # Create condition array - only the first True condition will be used conditions = jnp.array( [ diff_0_1, diff_2_3, diff_4_5, diff_6_7, diff_8_9, diff_10_11, diff_12_13, diff_14_15, diff_16_17, diff_18_19, diff_20_21, diff_22_23, diff_24_25, diff_26_27, ] ) # Create corresponding values array values = jnp.array( [ move_0_1, move_2_3, move_4_5, move_6_7, move_8_9, move_10_11, move_12_13, move_14_15, move_16_17, move_18_19, move_20_21, move_22_23, move_24_25, move_26_27, ] ) # Select the appropriate pattern based on which condition is True should_move = jnp.select(conditions, values, default=False) return should_move def determine_high_difficulty_speed(step_counter, difficulty): """Determine the speed (1 or 2+) for difficulties 28+.""" # Adjust difficulty to start from 0 for easier tier calculations diff_above_27 = difficulty - 28 # Each 16 difficulty levels form a tier (just like in shark/submarine algorithm) tier = diff_above_27 // 16 position_in_tier = diff_above_27 % 16 # Base speed for each tier (increases by 1 for each tier) base_speed = tier + 1 higher_speed = tier + 2 # Position brackets within tier (matches the pattern observed in paste.txt) pos_0 = position_in_tier == 0 pos_1_3 = jnp.logical_and(position_in_tier >= 1, position_in_tier <= 3) pos_4_6 = jnp.logical_and(position_in_tier >= 4, position_in_tier <= 6) pos_7_9 = jnp.logical_and(position_in_tier >= 7, position_in_tier <= 9) pos_10_12 = jnp.logical_and(position_in_tier >= 10, position_in_tier <= 12) pos_13_14 = jnp.logical_and(position_in_tier >= 13, position_in_tier <= 14) pos_15 = position_in_tier == 15 # Determine higher speed frequency based on position in tier # These frequencies match the observed patterns in paste.txt use_higher_speed_pos_0 = (step_counter % 16) == 15 # 1 in 16 frames (6.25%) use_higher_speed_pos_1_3 = (step_counter % 8) == 7 # 1 in 8 frames (12.5%) use_higher_speed_pos_4_6 = (step_counter % 4) == 3 # 1 in 4 frames (25%) use_higher_speed_pos_7_9 = (step_counter % 2) == 1 # 1 in 2 frames (50%) use_higher_speed_pos_10_12 = (step_counter % 4) != 0 # 3 in 4 frames (75%) use_higher_speed_pos_13_14 = (step_counter % 8) != 0 # 7 in 8 frames (87.5%) use_higher_speed_pos_15 = (step_counter % 16) != 0 # 15 in 16 frames (93.75%) # Select the appropriate higher speed frequency based on position # Use jnp.select for cleaner code with multiple conditions position_conditions = jnp.array( [pos_0, pos_1_3, pos_4_6, pos_7_9, pos_10_12, pos_13_14, pos_15] ) speed_values = jnp.array( [ use_higher_speed_pos_0, use_higher_speed_pos_1_3, use_higher_speed_pos_4_6, use_higher_speed_pos_7_9, use_higher_speed_pos_10_12, use_higher_speed_pos_13_14, use_higher_speed_pos_15, ] ) use_higher_speed = jnp.select(position_conditions, speed_values, default=False) # Calculate final speed: higher_speed or base_speed return jnp.where(use_higher_speed, higher_speed, base_speed) def move_single_diver(i, carry): # Unpack carry state - (positions, collected_count, diver_array) positions, collected, diver_array = carry diver_pos = positions[i] # Only process active divers (direction != 0) is_active = diver_pos[2] != 0 # Check for collision with player first if diver is active player_collision = jnp.logical_and( is_active, self.check_collision_single( jnp.array([state_player_x, state_player_y]), self.consts.PLAYER_SIZE, jnp.array([diver_pos[0], diver_pos[1]]), self.consts.DIVER_SIZE, ), ) # Only collect if we haven't reached max divers can_collect = state_divers_collected < 6 should_collect = jnp.logical_and(player_collision, can_collect) # Get the three sharks in the lane all_shark_lane_pos = jax.lax.dynamic_slice(shark_positions, (i * 3, 0), (3, 3)) # Get shark in the same lane for collision check shark_lane_pos = self.get_front_entity(i, all_shark_lane_pos) shark_collision = jnp.logical_and( is_active, self.check_collision_single( jnp.array([shark_lane_pos[0], shark_lane_pos[1]]), self.consts.SHARK_SIZE, jnp.array([diver_pos[0], diver_pos[1]]), self.consts.DIVER_SIZE, ), ) # check in which direction the shark is moving and copy the direction to the diver direction_of_shark = jnp.where( shark_lane_pos[2] == 0, diver_pos[2], shark_lane_pos[2] ) # Calculate movement based on difficulty movement_speed = calculate_diver_movement(step_counter, spawn_state.difficulty) should_move = movement_speed > 0 # Calculate movement direction (with speed factor) # If colliding with shark, use shark's direction/speed # Otherwise use diver's direction with appropriate speed factor movement_x = jnp.where( shark_collision, shark_lane_pos[2], # Use shark's direction/speed diver_pos[2] * movement_speed, # Apply difficulty-based speed ) # Calculate new position new_x = jnp.where( shark_collision, diver_pos[0] + movement_x, # Move with shark jnp.where( should_move, diver_pos[0] + movement_x, # Move with calculated speed diver_pos[0], # Stay still ), ) # Check bounds out_of_bounds = jnp.logical_or(new_x <= -8, new_x >= 170) # Create new position array - handle collection and bounds new_pos = jnp.where( jnp.logical_or(~is_active, jnp.logical_or(out_of_bounds, should_collect)), jnp.zeros(3), # Reset if out of bounds or collected jnp.array([new_x, self.consts.DIVER_SPAWN_POSITIONS[i], direction_of_shark]), ) # Update collection count if collected new_collected = collected + jnp.where(should_collect, 1, 0) # Update diver collection tracking - mark lane as collected when diver is collected updated_diver_array = diver_array.at[i].set( jnp.where(should_collect, 0, diver_array[i]) ) # if the diver went out of bounds set the entry to -1 updated_diver_array = updated_diver_array.at[i].set( jnp.where(out_of_bounds, -1, updated_diver_array[i]) ) # Update the diver position, collection count and diver_array return positions.at[i].set(new_pos), new_collected, updated_diver_array # Update all diver positions and track collections initial_carry = (diver_positions, state_divers_collected, new_diver_array) final_positions, final_collected, final_diver_array = jax.lax.fori_loop( 0, diver_positions.shape[0], move_single_diver, initial_carry ) # Handle case where all divers are collected - set all lanes to -1 # Apply the reset only if all divers have been collected reset_array = jnp.where( jnp.all(final_diver_array == 0), jnp.array([-1, -1, -1, -1], dtype=jnp.int32), # Randomized reset array final_diver_array, # Otherwise keep current state ) # Create updated spawn state updated_spawn_state = spawn_state._replace(diver_array=reset_array) return final_positions, final_collected, updated_spawn_state, rng
[docs] @partial(jax.jit, static_argnums=(0,)) def spawn_step( self, state, spawn_state: SpawnState, shark_positions: chex.Array, sub_positions: chex.Array, diver_positions: chex.Array, rng_key: chex.PRNGKey, ) -> Tuple[SpawnState, chex.Array, chex.Array, chex.Array, chex.Array]: """Main spawn handling function to be called in game step""" # Move existing enemies new_shark_positions, new_sub_positions, spawn_state_after_movement, new_key = ( self.step_enemy_movement( spawn_state, shark_positions, sub_positions, state.step_counter, rng_key ) ) # Update spawns using updated spawn state new_spawn_state, new_shark_positions, new_sub_positions, new_key = ( self.update_enemy_spawns( spawn_state_after_movement, new_shark_positions, new_sub_positions, diver_positions, state.step_counter, new_key, ) ) # Spawn new divers with updated tracking new_diver_positions, final_spawn_state = self.spawn_divers( new_spawn_state, diver_positions, new_shark_positions, new_sub_positions, state.step_counter, ) return ( final_spawn_state, new_shark_positions, new_sub_positions, new_diver_positions, new_key, )
[docs] def surface_sub_step(self, state: SeaquestState) -> chex.Array: # Check direction value specifically to get scalar boolean sub_exists = state.surface_sub_position[2] != 0 def spawn_sub(_): return jnp.array([159, 45, -1]) # Always spawns right facing left def move_sub(carry): sub_pos = carry new_x = jnp.where( state.step_counter % 4 == 0, sub_pos[0] - 1, # Direction always -1 sub_pos[0], ) # Return either zeros or new position return jnp.where( jnp.logical_or(new_x < -8, sub_pos[2] == 0), jnp.zeros(3), jnp.array([new_x, 45, -1]), ) # Each condition needs to be scalar enough_rescues = state.successful_rescues >= 2 enough_divers = state.divers_collected >= 1 correct_timing = jnp.logical_and( state.step_counter % 256 == 0, state.step_counter != 0 ) # check if the submarine should spawn should_spawn = jnp.logical_and( jnp.logical_and(enough_rescues, enough_divers), jnp.logical_and(correct_timing, ~sub_exists), ) temp1 = spawn_sub(state.surface_sub_position) temp2 = move_sub(state.surface_sub_position) return jnp.where(should_spawn, temp1, temp2)
[docs] @partial(jax.jit, static_argnums=(0,)) def enemy_missiles_step( self, curr_sub_positions, curr_enemy_missile_positions, step_counter, difficulty ) -> chex.Array: def calculate_missile_speed(step_counter, difficulty): """JAX-compatible missile speed calculation function""" # Base tier size is 16 difficulty levels tier_size = 16 # Determine base speed (1, 2, 3, etc.) based on difficulty tier base_speed = 1 + (difficulty // tier_size) # Calculate position within the current tier (0-15) position_in_tier = difficulty % tier_size # Special case for difficulty 0 is_diff_0 = difficulty == 0 # Create position bracket array for each pattern pos_brackets = jnp.array( [ jnp.logical_and( position_in_tier >= 0, position_in_tier <= 2 ), # 0-2: 6.25% jnp.logical_and( position_in_tier >= 3, position_in_tier <= 4 ), # 3-4: 12.5% jnp.logical_and( position_in_tier >= 5, position_in_tier <= 6 ), # 5-6: 25% jnp.logical_and( position_in_tier >= 7, position_in_tier <= 8 ), # 7-8: 50% jnp.logical_and( position_in_tier >= 9, position_in_tier <= 10 ), # 9-10: 75% jnp.logical_and( position_in_tier >= 11, position_in_tier <= 12 ), # 11-12: 87.5% jnp.logical_and( position_in_tier >= 13, position_in_tier <= 14 ), # 13-14: 93.75% position_in_tier == 15, # 15: 100% ] ) # Create array of higher speed patterns higher_speed_patterns = jnp.array( [ (step_counter % 16) == 0, # 6.25% (step_counter % 8) == 0, # 12.5% (step_counter % 4) == 0, # 25% (step_counter % 2) == 0, # 50% (step_counter % 4) != 0, # 75% (step_counter % 8) != 0, # 87.5% (step_counter % 16) != 0, # 93.75% True, # 100% ] ) # Use jnp.select to choose the pattern use_higher_speed = jnp.select( pos_brackets, higher_speed_patterns, default=False ) # Higher speed is base_speed + 1 higher_speed = base_speed + 1 # Handle difficulty 0 special case return jnp.where( is_diff_0, 1, jnp.where(use_higher_speed, higher_speed, base_speed) ) # 1. Define a function that operates on a SINGLE missile and its corresponding lane. # It no longer needs an index `i` or a `carry` argument. def vmapped_missile_update(missile_pos, lane_subs, lane_y_pos): # Get the front submarine for this specific lane sub_pos = self.get_front_entity(0, lane_subs) # Index 0 is fine since it only looks at the 3 subs passed in # Check if the missile should be spawned missile_exists = missile_pos[2] != 0 should_spawn = jnp.logical_and( ~missile_exists, (sub_pos[0] >= self.consts.MISSILE_SPAWN_POSITIONS[0]) & (sub_pos[0] <= self.consts.MISSILE_SPAWN_POSITIONS[1]) ) # Calculate new missile position new_missile_x = sub_pos[0] + 4 * sub_pos[2] spawned_missile = jnp.array([new_missile_x, lane_y_pos, sub_pos[2]]) new_missile = jnp.where(should_spawn, spawned_missile, missile_pos) # Move the missile if it exists movement_speed = calculate_missile_speed(step_counter, difficulty) velocity = movement_speed * new_missile[2] moved_missile = new_missile.at[0].add(velocity) new_missile = jnp.where(missile_exists, moved_missile, new_missile) # Check bounds and return is_out_of_bounds = (new_missile[0] < self.consts.X_BORDERS[0]) | (new_missile[0] > self.consts.X_BORDERS[1]) return jnp.where(is_out_of_bounds, jnp.zeros(3), new_missile) # 2. Prepare the inputs for vmap # Reshape subs into a per-lane format: (4 lanes, 3 subs per lane, 3 coords) all_lane_subs = curr_sub_positions.reshape(4, 3, 3) # 3. Use jax.vmap to apply the update function in parallel new_missile_positions = jax.vmap( vmapped_missile_update, in_axes=(0, 0, 0) # Map over missiles, sub-lanes, and y-positions )(curr_enemy_missile_positions, all_lane_subs, self.consts.ENEMY_MISSILE_Y) return new_missile_positions
[docs] @partial(jax.jit, static_argnums=(0,)) def player_missile_step( self, state: SeaquestState, curr_player_x, curr_player_y, action: chex.Array ) -> chex.Array: # check if the player shot this frame fire = jnp.any( jnp.array( [ action == Action.FIRE, action == Action.UPRIGHTFIRE, action == Action.UPLEFTFIRE, action == Action.DOWNFIRE, action == Action.DOWNRIGHTFIRE, action == Action.DOWNLEFTFIRE, action == Action.RIGHTFIRE, action == Action.LEFTFIRE, action == Action.UPFIRE, ] ) ) # IMPORTANT: do not change the order of this check, since the missile does not move in its first frame!! # also check if there is currently a missile in frame by checking if the player_missile_position is empty missile_exists = state.player_missile_position[2] != 0 # if the player shot and there is no missile in frame, then we can shoot a missile # the missile y is the current player y position + 7 # the missile x is either player x + 3 if facing left or player x + 13 if facing right new_missile = jnp.where( jnp.logical_and(fire, jnp.logical_not(missile_exists)), jnp.where( state.player_direction == -1, jnp.array([curr_player_x + 3, curr_player_y + 7, -1]), jnp.array([curr_player_x + 13, curr_player_y + 7, 1]), ), state.player_missile_position, ) # if a missile is in frame and exists, we move the missile further in the specified direction (5 per tick), also always put the missile at the current player y position new_missile = jnp.where( missile_exists, jnp.array( [new_missile[0] + new_missile[2] * 5, curr_player_y + 7, new_missile[2]] ), new_missile, ) # check if the new positions are still in bounds new_missile = jnp.where( new_missile[0] < self.consts.X_BORDERS[0], jnp.array([0, 0, 0]), jnp.where(new_missile[0] > self.consts.X_BORDERS[1], jnp.array([0, 0, 0]), new_missile), ) return new_missile
[docs] @partial(jax.jit, static_argnums=(0,)) def update_oxygen(self, state, player_x, player_y, player_missile_position): """Update oxygen levels and handle surfacing mechanics with proper surfacing detection""" PLAYER_BREATHING_Y = [47, 52] # Range where oxygen neither increases nor decreases # Detect actual surfacing moment at_surface = player_y == 46 was_underwater = player_y > 46 just_surfaced = jnp.logical_and(at_surface, state.just_surfaced == 0) # Check player state decrease_ox = player_y > PLAYER_BREATHING_Y[1] has_divers = state.divers_collected >= 0 # Changed to > 0 instead of >= 0 has_all_divers = state.divers_collected >= 6 needs_oxygen = state.oxygen < 64 # Special handling for initialization state in_init_state = state.just_surfaced == -1 started_diving = player_y > self.consts.PLAYER_START_Y filling_init_oxygen = jnp.logical_and(in_init_state, state.oxygen < 64) # Surfacing conditions increase_ox = jnp.logical_and(at_surface, needs_oxygen) stay_same = jnp.logical_and( player_y >= PLAYER_BREATHING_Y[0], player_y <= PLAYER_BREATHING_Y[1] ) # Calculate new divers count before other logic new_divers_collected = jnp.where( jnp.logical_and(just_surfaced, has_divers), jnp.where(in_init_state, state.divers_collected, state.divers_collected - 1), state.divers_collected, ) # Handle surfacing without divers - prevent during init # Only lose life if we started with no divers lose_life = jnp.logical_and( jnp.logical_and(just_surfaced, new_divers_collected < 0), jnp.logical_not(in_init_state), ) # Handle surfacing with all divers should_reset = jnp.logical_and(just_surfaced, has_all_divers) # Update surfacing flag with consideration for remaining divers new_just_surfaced = jnp.where( in_init_state, jnp.where( jnp.logical_and(started_diving, state.oxygen >= 63), jnp.array(0), jnp.array(-1), ), jnp.where( was_underwater, jnp.array(0), jnp.where(at_surface, jnp.array(1), state.just_surfaced), ), ) # Handle oxygen changes new_oxygen = jnp.where( filling_init_oxygen, jnp.where(state.step_counter % 2 == 0, state.oxygen + 1, state.oxygen), jnp.where( decrease_ox, jnp.where(state.step_counter % 32 == 0, state.oxygen - 1, state.oxygen), state.oxygen, ), ) # Important: Base blocking decision on has_divers instead of still_has_divers can_refill = jnp.logical_and(increase_ox, has_divers) new_oxygen = jnp.where( jnp.logical_and(can_refill, jnp.logical_not(in_init_state)), jnp.where( state.oxygen < 64, jnp.where(state.step_counter % 2 == 0, state.oxygen + 1, state.oxygen), state.oxygen, ), new_oxygen, ) # Increase difficulty when reaching max oxygen after surfacing old_difficulty = state.spawn_state.difficulty reached_max = jnp.logical_and( jnp.logical_and(new_oxygen >= 64, state.oxygen < 64), jnp.logical_not(in_init_state), ) new_difficulty = jnp.where(reached_max, old_difficulty + 1, old_difficulty) new_oxygen = jnp.where(stay_same, state.oxygen, new_oxygen) # Use has_divers for blocking decision and combine with oxygen check should_block = jnp.logical_and(at_surface, needs_oxygen) player_x = jnp.where(should_block, state.player_x, player_x) player_y = jnp.where( should_block, jnp.array(46, dtype=jnp.int32), # Force to exact surface position player_y, ) player_missile_position = jnp.where( should_block, jnp.zeros(3), player_missile_position ) # Prevent oxygen depletion during init oxygen_depleted = jnp.logical_and( new_oxygen <= jnp.array(0), jnp.logical_not(in_init_state) ) return ( new_oxygen, player_x, player_y, player_missile_position, oxygen_depleted, lose_life, new_divers_collected, should_reset, new_just_surfaced, new_difficulty, )
[docs] @partial(jax.jit, static_argnums=(0,)) def player_step( self, state: SeaquestState, action: chex.Array ) -> tuple[chex.Array, chex.Array, chex.Array]: # implement all the possible movement directions for the player, the mapping is: # anything with left in it, add -1 to the x position # anything with right in it, add 1 to the x position # anything with up in it, add -1 to the y position # anything with down in it, add 1 to the y position up = jnp.any( jnp.array( [ action == Action.UP, action == Action.UPRIGHT, action == Action.UPLEFT, action == Action.UPFIRE, action == Action.UPRIGHTFIRE, action == Action.UPLEFTFIRE, ] ) ) down = jnp.any( jnp.array( [ action == Action.DOWN, action == Action.DOWNRIGHT, action == Action.DOWNLEFT, action == Action.DOWNFIRE, action == Action.DOWNRIGHTFIRE, action == Action.DOWNLEFTFIRE, ] ) ) left = jnp.any( jnp.array( [ action == Action.LEFT, action == Action.UPLEFT, action == Action.DOWNLEFT, action == Action.LEFTFIRE, action == Action.UPLEFTFIRE, action == Action.DOWNLEFTFIRE, ] ) ) right = jnp.any( jnp.array( [ action == Action.RIGHT, action == Action.UPRIGHT, action == Action.DOWNRIGHT, action == Action.RIGHTFIRE, action == Action.UPRIGHTFIRE, action == Action.DOWNRIGHTFIRE, ] ) ) player_x = jnp.where( right, state.player_x + 1, jnp.where(left, state.player_x - 1, state.player_x) ) player_y = jnp.where( down, state.player_y + 1, jnp.where(up, state.player_y - 1, state.player_y) ) # set the direction according to the movement player_direction = jnp.where(right, 1, jnp.where(left, -1, state.player_direction)) # perform out of bounds checks player_x = jnp.where( player_x < self.consts.PLAYER_BOUNDS[0][0], self.consts.PLAYER_BOUNDS[0][0], # Clamp to min player bound jnp.where( player_x > self.consts.PLAYER_BOUNDS[0][1], self.consts.PLAYER_BOUNDS[0][1], # Clamp to max player bound player_x, ), ) player_y = jnp.where( player_y < self.consts.PLAYER_BOUNDS[1][0], self.consts.PLAYER_BOUNDS[1][0], jnp.where(player_y > self.consts.PLAYER_BOUNDS[1][1], self.consts.PLAYER_BOUNDS[1][1], player_y), ) return player_x, player_y, player_direction
[docs] @partial(jax.jit, static_argnums=(0,)) def calculate_kill_points(self, successful_rescues: chex.Array) -> chex.Array: """Calculate the points awarded for killing a shark or submarine. Sharks and submarines are worth 20 points. The points are increased by 10 for each successful rescue with a maximum of 90.""" base_points = 20 max_points = 90 additional_points = 10 * successful_rescues return jnp.minimum(base_points + additional_points, max_points)
# Minimal ALE action set for Seaquest (from scripts/action_space_helper.py) ACTION_SET: jnp.ndarray = jnp.array( [ Action.NOOP, Action.FIRE, Action.UP, Action.RIGHT, Action.LEFT, Action.DOWN, Action.UPRIGHT, Action.UPLEFT, Action.DOWNRIGHT, Action.DOWNLEFT, Action.UPFIRE, Action.RIGHTFIRE, Action.LEFTFIRE, Action.DOWNFIRE, Action.UPRIGHTFIRE, Action.UPLEFTFIRE, Action.DOWNRIGHTFIRE, Action.DOWNLEFTFIRE, ], dtype=jnp.int32, ) def __init__(self, consts: SeaquestConstants = None): consts = consts or SeaquestConstants() super().__init__(consts) self.obs_size = 6 + 12 * 5 + 12 * 5 + 4 * 5 + 4 * 5 + 5 + 5 + 4 self.renderer = SeaquestRenderer(self.consts)
[docs] @partial(jax.jit, static_argnums=(0,)) def render(self, state: SeaquestState) -> jnp.ndarray: """Render the game state to a raster image.""" return self.renderer.render(state)
[docs] def flatten_entity_position(self, entity: EntityPosition) -> jnp.ndarray: return jnp.concatenate([ jnp.array([entity.x], dtype=jnp.int32), jnp.array([entity.y], dtype=jnp.int32), jnp.array([entity.width], dtype=jnp.int32), jnp.array([entity.height], dtype=jnp.int32), jnp.array([entity.active], dtype=jnp.int32) ])
[docs] def flatten_player_entity(self, entity: PlayerEntity) -> jnp.ndarray: return jnp.concatenate([ jnp.array([entity.x], dtype=jnp.int32), jnp.array([entity.y], dtype=jnp.int32), jnp.array([entity.o], dtype=jnp.int32), jnp.array([entity.width], dtype=jnp.int32), jnp.array([entity.height], dtype=jnp.int32), jnp.array([entity.active], dtype=jnp.int32) ])
[docs] @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: SeaquestObservation) -> jnp.ndarray: return jnp.concatenate([ self.flatten_player_entity(obs.player), obs.sharks.flatten().astype(jnp.int32), obs.submarines.flatten().astype(jnp.int32), obs.divers.flatten().astype(jnp.int32), obs.enemy_missiles.flatten().astype(jnp.int32), self.flatten_entity_position(obs.surface_submarine), self.flatten_entity_position(obs.player_missile), obs.collected_divers.flatten().astype(jnp.int32), obs.player_score.flatten().astype(jnp.int32), obs.lives.flatten().astype(jnp.int32), obs.oxygen_level.flatten().astype(jnp.int32), ])
[docs] def action_space(self) -> spaces.Discrete: return spaces.Discrete(len(self.ACTION_SET))
[docs] def observation_space(self) -> spaces.Dict: """Returns the observation space for Seaquest. The observation contains: - player: PlayerEntity (x, y, o, width, height, active) - sharks: array of shape (12, 5) with x,y,width,height,active for each shark - submarines: array of shape (12, 5) with x,y,width,height,active for each submarine - divers: array of shape (4, 5) with x,y,width,height,active for each diver - enemy_missiles: array of shape (4, 5) with x,y,width,height,active for each missile - surface_submarine: EntityPosition (x, y, width, height, active) - player_missile: EntityPosition (x, y, width, height, active) - collected_divers: int (0-6) - player_score: int (0-999999) - lives: int (0-3) - oxygen_level: int (0-255) """ return spaces.Dict({ "player": spaces.Dict({ "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "o": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "active": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }), "sharks": spaces.Box(low=0, high=160, shape=(12, 5), dtype=jnp.int32), "submarines": spaces.Box(low=0, high=160, shape=(12, 5), dtype=jnp.int32), "divers": spaces.Box(low=0, high=160, shape=(4, 5), dtype=jnp.int32), "enemy_missiles": spaces.Box(low=0, high=160, shape=(4, 5), dtype=jnp.int32), "surface_submarine": spaces.Dict({ "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "active": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }), "player_missile": spaces.Dict({ "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "active": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }), "collected_divers": spaces.Box(low=0, high=6, shape=(), dtype=jnp.int32), "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), "lives": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), "oxygen_level": spaces.Box(low=0, high=255, shape=(), dtype=jnp.int32), })
[docs] def image_space(self) -> spaces.Box: """Returns the image space for Seaquest. The image is a RGB image with shape (210, 160, 3). """ return spaces.Box( low=0, high=255, shape=(210, 160, 3), dtype=jnp.uint8 )
@partial(jax.jit, static_argnums=(0, )) def _get_observation(self, state: SeaquestState) -> SeaquestObservation: # Create player (already scalar, no need for vectorization) player = PlayerEntity( x=state.player_x, y=state.player_y, o=state.player_direction, width=jnp.array(self.consts.PLAYER_SIZE[0]), height=jnp.array(self.consts.PLAYER_SIZE[1]), active=jnp.array(1), # Player is always active ) # Define a function to convert enemy positions to entity format def convert_to_entity(pos, size): return jnp.array([ pos[0], # x position pos[1], # y position size[0], # width size[1], # height pos[2] != 0, # active flag ]) # Apply conversion to each type of entity using vmap # Sharks sharks = jax.vmap(lambda pos: convert_to_entity(pos, self.consts.SHARK_SIZE))( state.shark_positions ) # Submarines submarines = jax.vmap(lambda pos: convert_to_entity(pos, self.consts.ENEMY_SUB_SIZE))( state.sub_positions ) # Divers divers = jax.vmap(lambda pos: convert_to_entity(pos, self.consts.DIVER_SIZE))( state.diver_positions ) # Enemy missiles enemy_missiles = jax.vmap(lambda pos: convert_to_entity(pos, self.consts.MISSILE_SIZE))( state.enemy_missile_positions ) # Surface submarine (scalar) surface_pos = state.surface_sub_position surface_sub = EntityPosition( x=surface_pos[0], # First item of first dimension y=surface_pos[1], # First item of second dimension width=jnp.array(self.consts.ENEMY_SUB_SIZE[0]), height=jnp.array(self.consts.ENEMY_SUB_SIZE[1]), active=jnp.array(surface_pos[2] != 0), ) # Player missile (scalar) missile_pos = state.player_missile_position player_missile = EntityPosition( x=missile_pos[0], y=missile_pos[1], width=jnp.array(self.consts.MISSILE_SIZE[0]), height=jnp.array(self.consts.MISSILE_SIZE[1]), active=jnp.array(missile_pos[2] != 0), ) # Return observation return SeaquestObservation( player=player, sharks=sharks, submarines=submarines, divers=divers, enemy_missiles=enemy_missiles, surface_submarine=surface_sub, player_missile=player_missile, collected_divers=state.divers_collected, player_score=state.score, lives=state.lives, oxygen_level=state.oxygen, ) @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: SeaquestState) -> SeaquestInfo: return SeaquestInfo( successful_rescues=state.successful_rescues, difficulty=state.spawn_state.difficulty, step_counter=state.step_counter, ) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: SeaquestState, state: SeaquestState): return state.score - previous_state.score @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: SeaquestState) -> bool: return state.lives < 0
[docs] @partial(jax.jit, static_argnums=(0,)) def reset(self, key: jax.random.PRNGKey = jax.random.PRNGKey(42)) -> Tuple[SeaquestObservation, SeaquestState]: """Initialize game state""" reset_state = SeaquestState( player_x=jnp.array(self.consts.PLAYER_START_X), player_y=jnp.array(self.consts.PLAYER_START_Y), player_direction=jnp.array(0), oxygen=jnp.array(0), # Full oxygen divers_collected=jnp.array(0), score=jnp.array(0), lives=jnp.array(3), spawn_state=self.initialize_spawn_state(), diver_positions=jnp.zeros((self.consts.MAX_DIVERS, 3)), # 4 divers shark_positions=jnp.zeros((self.consts.MAX_SHARKS, 3)), sub_positions=jnp.zeros((self.consts.MAX_SUBS, 3)), # x, y, direction enemy_missile_positions=jnp.zeros((self.consts.MAX_ENEMY_MISSILES, 3)), # 4 missiles surface_sub_position=jnp.zeros(3), # 1 surface sub player_missile_position=jnp.zeros(3), # x,y,direction step_counter=jnp.array(0), just_surfaced=jnp.array(-1), successful_rescues=jnp.array(0), death_counter=jnp.array(0), rng_key=key, ) initial_obs = self._get_observation(reset_state) return initial_obs, reset_state
[docs] @partial(jax.jit, static_argnums=(0, )) def step( self, state: SeaquestState, action: chex.Array ) -> Tuple[SeaquestObservation, SeaquestState, float, bool, SeaquestInfo]: # Translate compact agent action index to ALE console action atari_action = jnp.take(self.ACTION_SET, action.astype(jnp.int32)) previous_state = state _, reset_state = self.reset(state.rng_key) # First handle death animation if active def handle_death_animation(): # This outer conditional remains the same. # It decides if the animation is over or still running. is_animation_over = state.death_counter <= 1 def on_animation_continue(): # This is the original logic for when the animation is still running. # It correctly updates the Y-positions and the player visibility. shark_y_positions, _, _, _ = self.step_enemy_movement( state.spawn_state, state.shark_positions, state.sub_positions, state.step_counter, state.rng_key, ) new_shark_positions = state.shark_positions.at[:, 1].set( shark_y_positions[:, 1] ) should_hide_player = state.death_counter <= 45 return state._replace( death_counter=state.death_counter - 1, shark_positions=new_shark_positions, sub_positions=state.sub_positions, enemy_missile_positions=state.enemy_missile_positions, player_missile_position=jnp.zeros(3), player_x=jnp.where(should_hide_player, -100, state.player_x), step_counter=state.step_counter + 1, ) def on_animation_over(): # This is the new, more precise logic for when the animation ends. # We add a check to see if this is the absolute final life. is_final_life = state.lives <= 0 def handle_game_over(): # If this is the final life, return the TRUE final state of the game # while setting lives to -1 to trigger done=True. # This is what your snapshot test needs to see. return state._replace( lives=state.lives - 1, death_counter=0, ) def handle_stage_reset(): # If the player still has lives left, perform the original stage reset. # This preserves the mechanic of resetting the level after losing a life. return reset_state._replace( lives=state.lives - 1, score=state.score, successful_rescues=state.successful_rescues, divers_collected=jnp.maximum(state.divers_collected - 1, 0), spawn_state=self.soft_reset_spawn_state(state.spawn_state), ) # Use the new nested conditional to choose the correct outcome. return jax.lax.cond( is_final_life, lambda: handle_game_over(), lambda: handle_stage_reset(), ) # This is the main conditional call. return jax.lax.cond( is_animation_over, lambda: on_animation_over(), lambda: on_animation_continue(), ) def handle_score_freeze(): # on scoring, the death counter will be set to -(oxygen * 2 + 16 * 6) # thats when we get in here, so duplicate the death animation pattern, but decrease the oxygen until its 0 # Calculate new positions with frozen X coordinates shark_y_positions, _, _, _ = self.step_enemy_movement( state.spawn_state, state.shark_positions, state.sub_positions, state.step_counter, state.rng_key, ) # Keep X positions from original state, only update Y new_shark_positions = state.shark_positions.at[:, 1].set( shark_y_positions[:, 1] ) # calculate the new oxygen new_ox = jnp.where( state.death_counter % 2 == 0, state.oxygen - 1, state.oxygen ) new_ox = jnp.where(new_ox <= 0, jnp.array(0), state.oxygen) # Return either final reset or animation frame return jax.lax.cond( state.death_counter >= -1, lambda _: reset_state._replace( player_x=state.player_x, player_y=state.player_y, player_direction=state.player_direction, score=state.score, lives=state.lives, successful_rescues=state.successful_rescues, divers_collected=jnp.array(0), spawn_state=self.soft_reset_spawn_state(state.spawn_state), surface_sub_position=state.surface_sub_position, oxygen=jnp.array(0), ), lambda _: state._replace( death_counter=state.death_counter + 1, shark_positions=new_shark_positions, sub_positions=state.sub_positions, enemy_missile_positions=state.enemy_missile_positions, player_missile_position=jnp.zeros(3), step_counter=state.step_counter + 1, oxygen=new_ox, ), operand=None, ) # Normal game logic starts here def normal_game_step(): # First check if player should be frozen for oxygen refill at_surface = state.player_y == 46 needs_oxygen = state.oxygen < 64 should_block = jnp.logical_and(at_surface, needs_oxygen) # while player is frozen, keep resetting the spawn counter new_spawn_state = jax.lax.cond( should_block, lambda: state.spawn_state._replace( spawn_timers=jnp.array([80, 80, 80, 120], dtype=jnp.int32) ), lambda: state.spawn_state, ) state_updated = state._replace(spawn_state=new_spawn_state) # If blocked, force position and disable actions player_x = jnp.where(should_block, state.player_x, state.player_x) player_y = jnp.where( should_block, jnp.array(46, dtype=jnp.int32), state.player_y ) action_mod = jnp.where(should_block, jnp.array(Action.NOOP), atari_action) # Now calculate movement using potentially modified positions and action next_x, next_y, player_direction = self.player_step( state._replace(player_x=player_x, player_y=player_y), action_mod ) player_missile_position = self.player_missile_step( state, next_x, next_y, action_mod ) # Rest of oxygen handling and game logic ( new_oxygen, player_x, player_y, player_missile_position, oxygen_depleted, lose_life_surfacing, new_divers_collected, should_reset, new_just_surfaced, new_difficulty, ) = self.update_oxygen(state, next_x, next_y, player_missile_position) # Update divers collected count from oxygen mechanics state_updated = state_updated._replace( divers_collected=new_divers_collected ) # update the spawn state with the new difficulty new_spawn_state = state_updated.spawn_state._replace( difficulty=new_difficulty ) # Check missile collisions ( player_missile_position, new_shark_positions, new_sub_positions, new_score, updated_spawn_state, new_rng_key, ) = self.check_missile_collisions( player_missile_position, state_updated.shark_positions, state_updated.sub_positions, state_updated.score, state_updated.successful_rescues, new_spawn_state, state.rng_key, ) # perform all necessary spawn steps ( new_spawn_state, new_shark_positions, new_sub_positions, new_diver_positions, new_rng_key, ) = self.spawn_step( state_updated, updated_spawn_state, new_shark_positions, new_sub_positions, state.diver_positions, new_rng_key, ) new_diver_positions, new_divers_collected, new_spawn_state, new_rng_key = ( self.step_diver_movement( new_diver_positions, new_shark_positions, player_x, player_y, state_updated.divers_collected, new_spawn_state, state_updated.step_counter, new_rng_key, ) ) new_surface_sub_pos = self.surface_sub_step(state_updated) state_updated._replace(surface_sub_position=new_surface_sub_pos) # update the enemy missile positions new_enemy_missile_positions = self.enemy_missiles_step( new_sub_positions, state_updated.enemy_missile_positions, state_updated.step_counter, state_updated.spawn_state.difficulty, ) # append the surface submarine to the other submarines for the collision check # check if the player has collided with any of the enemies player_collision, collision_points = self.check_player_collision( player_x, player_y, new_sub_positions, new_shark_positions, new_surface_sub_pos, state_updated.enemy_missile_positions, new_score, state_updated.successful_rescues, ) lose_life = jnp.any( jnp.array([oxygen_depleted, player_collision, lose_life_surfacing]) ) # Start death animation but keep divers intact during animation death_animation_state = state_updated._replace( score=state.score + collision_points, death_counter=jnp.array(90), spawn_state=self.soft_reset_spawn_state(state_updated.spawn_state), ) # Calculate points for rescuing divers. Each diver is worth 50 points. # Each successful rescue adds 50 points with a maximum of 1000 points each. base_points_per_diver = 50 max_points_per_diver = 1000 additional_points_per_rescue = 50 * state.successful_rescues points_per_diver = jnp.minimum( base_points_per_diver + additional_points_per_rescue, max_points_per_diver, ) total_diver_points = points_per_diver * state.divers_collected # Calculate bonus points for remaining oxygen oxygen_bonus = state.oxygen * 20 # Calculate total points for successful rescue total_rescue_points = total_diver_points + oxygen_bonus # TODO: somewhere the oxygen is depleted on surfacing, this currently blocks the slow draining of oxygen (which is not gameplay relevant -> low priority) # scoring freeze, 16 ticks per diver i.e. 6 * 16 and also 2 ticks per remaining oxygen (which is drained!) # Create the scoring state scoring_state = state_updated._replace( player_x=player_x, player_y=player_y, player_direction=player_direction, lives=state_updated.lives, score=state_updated.score + total_rescue_points, successful_rescues=state_updated.successful_rescues + 1, spawn_state=self.soft_reset_spawn_state(state_updated.spawn_state)._replace( difficulty=state_updated.spawn_state.difficulty + 1, survived=state_updated.spawn_state.survived.astype(jnp.int32) ), death_counter=jnp.array(-(96 + state_updated.oxygen * 2)), ) # cap the step counter to 1024 new_step_counter = jnp.where( state_updated.step_counter == 1024, jnp.array(0), state_updated.step_counter + 1, ) # Create the normal returned state normal_returned_state = SeaquestState( player_x=player_x, player_y=player_y, player_direction=player_direction, oxygen=new_oxygen, divers_collected=new_divers_collected, score=new_score, lives=state_updated.lives, spawn_state=new_spawn_state._replace( survived=new_spawn_state.survived.astype(jnp.int32) ), diver_positions=new_diver_positions, shark_positions=new_shark_positions, sub_positions=new_sub_positions, enemy_missile_positions=new_enemy_missile_positions, surface_sub_position=new_surface_sub_pos, player_missile_position=player_missile_position, step_counter=new_step_counter, just_surfaced=new_just_surfaced, successful_rescues=state_updated.successful_rescues, death_counter=jnp.array(0), rng_key=new_rng_key, ) # First handle surfacing with all divers (scoring) intermediate_state = jax.lax.cond( should_reset, lambda _: scoring_state, lambda _: normal_returned_state, operand=None, ) # Then handle life loss - start death animation instead of immediate reset final_state = jax.lax.cond( lose_life, lambda _: death_animation_state, lambda _: intermediate_state, operand=None, ) # Check for additional life every 10,000 points additional_lives = (final_state.score // 10000) - (state.score // 10000) new_lives = jnp.minimum(final_state.lives + additional_lives, 6) # max 6 lives possible # Update the final state with new lives final_state = final_state._replace(lives=new_lives) # Check if the game is over game_over = final_state.lives <= -1 # Handle game over state return jax.lax.cond( game_over, lambda _: state._replace( score=final_state.score, lives=jnp.array(-1), death_counter=jnp.array(0), ), lambda _: final_state, operand=None, ) return_state = jax.lax.cond( state.death_counter > 0, lambda _: handle_death_animation(), lambda _: jax.lax.cond( state.death_counter < 0, lambda _: handle_score_freeze(), lambda _: normal_game_step(), operand=None, ), operand=None, ) # Get observation and info observation = self._get_observation(return_state) done = self._get_done(return_state) env_reward = self._get_reward(previous_state, return_state) info = self._get_info(return_state) # Choose between death animation and normal game step return observation, return_state, env_reward, done, info
[docs] class SeaquestRenderer(JAXGameRenderer): def __init__(self, consts: SeaquestConstants = None): super().__init__() self.consts = consts or SeaquestConstants() self.config = render_utils.RendererConfig( game_dimensions=(210, 160), channels=3, ) self.jr = render_utils.JaxRenderingUtils(self.config) # 1. Start from (possibly modded) asset config provided via constants final_asset_config = list(self.consts.ASSET_CONFIG) # 2. Create procedural assets using modded constants procedural_sprites = self._create_procedural_sprites() # 3. Append procedural assets for name, data in procedural_sprites.items(): final_asset_config.append({'name': name, 'type': 'procedural', 'data': data}) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/seaquest" # 4. Load all assets, create palette, and generate ID masks ( self.PALETTE, self.SHAPE_MASKS, self.BACKGROUND, self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(final_asset_config, sprite_path) self.SHARK_COLOR_MAP = self._precompute_shark_color_map() def _create_procedural_sprites(self) -> dict: """Creates 1x1 pixel sprites to ensure colors are in the palette.""" procedural_sprites = {} for i, color in enumerate(self.consts.SHARK_DIFFICULTY_COLORS): rgba = jnp.array(list(color) + [255], dtype=jnp.uint8).reshape(1, 1, 4) procedural_sprites[f'shark_color_{i}'] = rgba rgba_oxy = jnp.array(list(self.consts.OXYGEN_BAR_COLOR[:3]) + [255], dtype=jnp.uint8).reshape(1, 1, 4) procedural_sprites['oxygen_bar_color'] = rgba_oxy return procedural_sprites def _precompute_shark_color_map(self) -> jnp.ndarray: """Creates a lookup table mapping difficulty (0-7) to a shark color ID.""" color_cycle_indices = jnp.array([0, 1, 2, 3, 0, 1, 0, 3]) cycle_rgb_colors = self.consts.SHARK_DIFFICULTY_COLORS[color_cycle_indices] return jnp.array([self.COLOR_TO_ID[tuple(rgb)] for rgb in np.array(cycle_rgb_colors)]) # --- Sequential Rendering for Batched Objects ---
[docs] def render_object_sequentially(self, current_raster, pos, shape_masks, flip_offsets, anim_idx): """Helper to render a single object with a pre-calculated animation index.""" is_active = pos[2] != 0 return jax.lax.cond( is_active, lambda r: self.jr.render_at_clipped(r, pos[0], pos[1], shape_masks[anim_idx], flip_horizontal=pos[2] == self.consts.FACE_LEFT, flip_offset=flip_offsets), lambda r: r, current_raster )
# --- Render Divers --- # Original cycle: Frame 0 for 16 steps, Frame 1 for 4 steps. Total = 20 steps. def _draw_divers(self, raster, state): diver_anim_idx = jax.lax.select((state.step_counter % 20) < 16, 0, 1) raster = jax.lax.fori_loop( 0, state.diver_positions.shape[0], lambda i, r: self.render_object_sequentially(r, state.diver_positions[i], self.SHAPE_MASKS['diver'], self.FLIP_OFFSETS['diver'], diver_anim_idx), raster ) # --- Render Enemy Subs --- # Original cycle: 3 frames, each for 4 steps. Total = 12 steps. enemy_sub_anim_idx = (state.step_counter % 12) // 4 all_subs = jnp.concatenate([state.sub_positions, state.surface_sub_position[None, :]]) raster = jax.lax.fori_loop( 0, all_subs.shape[0], lambda i, r: self.render_object_sequentially(r, all_subs[i], self.SHAPE_MASKS['enemy_sub'], self.FLIP_OFFSETS['enemy_sub'], enemy_sub_anim_idx), raster ) return raster
[docs] @partial(jax.jit, static_argnames=['self']) def render(self, state: SeaquestState) -> jnp.ndarray: raster = self.BACKGROUND # Use the raw step_counter for precise animation control step_counter = state.step_counter # --- Player & Player Torpedo --- # Original cycle: 3 frames, each shown for 4 steps. Total = 12 steps. player_anim_idx = (step_counter % 12) // 4 raster = self.jr.render_at( raster, state.player_x, state.player_y, self.SHAPE_MASKS['player_sub'][player_anim_idx], flip_horizontal=state.player_direction == self.consts.FACE_LEFT, flip_offset=self.FLIP_OFFSETS['player_sub'] ) torp = state.player_missile_position raster = jax.lax.cond( torp[2] != 0, lambda r: self.jr.render_at_clipped(r, torp[0], torp[1], self.SHAPE_MASKS['player_torp'], flip_horizontal=torp[2] == self.consts.FACE_LEFT), lambda r: r, raster ) raster = self._draw_divers(raster, state) # --- Render Enemy Torpedoes --- # No animation, so index is always 0 raster = jax.lax.fori_loop( 0, state.enemy_missile_positions.shape[0], lambda i, r: self.render_object_sequentially(r, state.enemy_missile_positions[i], self.SHAPE_MASKS['enemy_torp'][None, ...], jnp.zeros(2, dtype=jnp.int32), 0), raster ) # --- Render Sharks --- # Original cycle: Frame 0 for 16 steps, Frame 1 for 8 steps. Total = 24 steps. shark_anim_idx = jax.lax.select((step_counter % 24) < 16, 0, 1) difficulty_idx = state.spawn_state.difficulty % 8 shark_color_id = self.SHARK_COLOR_MAP[difficulty_idx] base_shark_masks = self.SHAPE_MASKS['shark_base'] recolored_shark_masks = jnp.where(base_shark_masks != self.jr.TRANSPARENT_ID, shark_color_id, base_shark_masks) raster = jax.lax.fori_loop( 0, state.shark_positions.shape[0], lambda i, r: self.render_object_sequentially(r, state.shark_positions[i], recolored_shark_masks, self.FLIP_OFFSETS['shark_base'], shark_anim_idx), raster ) # --- UI Elements (Unchanged) --- score_digits = self.jr.int_to_digits(state.score, max_digits=6) raster = self.jr.render_label(raster, 58, 18, score_digits, self.SHAPE_MASKS['digits'], spacing=8, max_digits=6) raster = self.jr.render_indicator(raster, 14, 28, state.lives, self.SHAPE_MASKS['life_indicator'], spacing=10, max_value=3) raster = self.jr.render_indicator(raster, 49, 178, state.divers_collected, self.SHAPE_MASKS['diver_indicator'], spacing=10, max_value=6) oxygen_color_id = self.COLOR_TO_ID[tuple(np.array(self.consts.OXYGEN_BAR_COLOR[:3]))] raster = self.jr.render_bar(raster, 49, 170, state.oxygen, 64, 63, 5, oxygen_color_id, self.jr.TRANSPARENT_ID) raster = self.jr.draw_rects( raster, positions=jnp.array([[0, 0]]), sizes=jnp.array([[8, self.config.game_dimensions[0]]]), color_id=self.BACKGROUND[0, 0] ) return self.jr.render_from_palette(raster, self.PALETTE)