Seaquest Environment

class jaxatari.games.jax_seaquest.CarryState(missile_pos, shark_pos, sub_pos, score)

Bases: NamedTuple

missile_pos: Array | ndarray | bool_ | number

Alias for field number 0

score: Array | ndarray | bool_ | number

Alias for field number 3

shark_pos: Array | ndarray | bool_ | number

Alias for field number 1

sub_pos: Array | ndarray | bool_ | number

Alias for field number 2

class jaxatari.games.jax_seaquest.EntityPosition(x, y, width, height, active)

Bases: NamedTuple

active: Array

Alias for field number 4

height: Array

Alias for field number 3

width: Array

Alias for field number 2

x: Array

Alias for field number 0

y: Array

Alias for field number 1

class jaxatari.games.jax_seaquest.JaxSeaquest(consts: SeaquestConstants | None = None)

Bases: JaxEnvironment[SeaquestState, SeaquestObservation, SeaquestInfo, SeaquestConstants]

ACTION_SET: Array = Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], dtype=int32)

action_space() → Discrete

Returns the action space of the environment. Returns: A Discrete space over the 18 actions in ACTION_SET.
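As a sketch of how a random policy could pick actions, the snippet below samples uniformly from an array that mirrors the documented ACTION_SET (integers 0 to 17). The function name random_action is hypothetical and not part of jaxatari.

```python
import jax
import jax.numpy as jnp

# Mirrors the documented ACTION_SET: 18 discrete action ids, 0..17.
ACTION_SET = jnp.arange(18, dtype=jnp.int32)


def random_action(key):
    """Hypothetical helper: pick one action id uniformly at random."""
    return jax.random.choice(key, ACTION_SET)


# One independent key per sample, vectorized with vmap.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
actions = jax.vmap(random_action)(keys)
print(actions.shape)  # (5,)
```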

calculate_kill_points(successful_rescues: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Calculate the points awarded for killing a shark or submarine. Each kill is worth 20 points, plus 10 points for each successful rescue, capped at a maximum of 90 points.
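The stated rule fits in a single vectorized expression. This is a sketch of the rule as described, not the library's code; kill_points is a hypothetical name.

```python
import jax.numpy as jnp


def kill_points(successful_rescues):
    """Hypothetical helper: 20 points per kill, +10 per rescue, capped at 90."""
    rescues = jnp.asarray(successful_rescues)
    return jnp.minimum(20 + 10 * rescues, 90)


print(kill_points(jnp.arange(9)).tolist())
# [20, 30, 40, 50, 60, 70, 80, 90, 90]
```

Because it uses only jnp operations, the same function works on a scalar or on a whole batch of rescue counts and can be traced under jax.jit.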

check_collision_batch(pos1, size1, pos2_array, size2)

Check collision between one entity and an array of entities.

check_collision_single(pos1, size1, pos2, size2)

Check collision between two single entities.
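Both checks can be expressed as an axis-aligned bounding-box (AABB) overlap test, with broadcasting handling the batched case. The sketch below assumes positions are top-left (x, y) corners and sizes are (width, height); boxes_overlap is a hypothetical helper, and the shark positions are made up.

```python
import jax.numpy as jnp


def boxes_overlap(pos1, size1, pos2, size2):
    """Hypothetical AABB test.

    Assumes pos = (x, y) of the top-left corner and size = (width, height).
    pos2 may be a single (2,) position or an (N, 2) array; broadcasting
    then tests the first box against every row at once.
    """
    pos1, size1, pos2, size2 = (jnp.asarray(a) for a in (pos1, size1, pos2, size2))
    # Boxes overlap when their intervals overlap on both the x and y axes.
    overlap_per_axis = (pos1 < pos2 + size2) & (pos2 < pos1 + size1)
    return jnp.all(overlap_per_axis, axis=-1)


player_pos, player_size = (76, 46), (16, 11)       # PLAYER_START_X/Y, PLAYER_SIZE
sharks = jnp.array([[80, 50], [0, 0], [100, 46]])  # made-up positions
shark_size = (8, 7)                                # SHARK_SIZE

print(boxes_overlap(player_pos, player_size, sharks, shark_size).tolist())
# [True, False, False]
```

With strict inequalities, boxes that only touch at an edge do not collide. The real methods may use a different convention or additionally ignore empty slots.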

check_missile_collisions(missile_pos: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, score: Array | ndarray | bool_ | number, successful_rescues: Array | ndarray | bool_ | number, spawn_state: SpawnState, rng_key: Array) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]

Check for collisions between player missile and enemies using a vectorized approach.

check_player_collision(player_x, player_y, submarine_list, shark_list, surface_sub_pos, enemy_projectile_list, score, successful_rescues) → Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]

enemy_missiles_step(curr_sub_positions, curr_enemy_missile_positions, step_counter, difficulty) → Array | ndarray | bool_ | number

flatten_entity_position(entity: EntityPosition) → Array

flatten_player_entity(entity: PlayerEntity) → Array

get_front_entity(i, lane_positions)

get_pattern_for_difficulty(current_pattern: Array | ndarray | bool_ | number, moving_left: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Returns the spawn pattern based on the lane’s current wave/pattern number.

Pattern meanings:

0: Single enemy (initial pattern)
1: Two adjacent enemies
2: Two enemies with a gap
3: Three enemies in a row
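One way to picture these patterns is as occupancy masks over a small row of spawn slots. The sketch below is hypothetical: the three-slot layout, the clamping of out-of-range pattern numbers, and the mirroring for left-moving lanes are all assumptions, not the library's encoding.

```python
import jax.numpy as jnp

# Hypothetical encoding: one row per pattern, one column per spawn slot
# (1 = spawn an enemy in that slot).
PATTERN_SLOTS = jnp.array([
    [1, 0, 0],  # 0: single enemy
    [1, 1, 0],  # 1: two adjacent enemies
    [1, 0, 1],  # 2: two enemies with a gap
    [1, 1, 1],  # 3: three enemies in a row
], dtype=jnp.int32)


def slots_for_pattern(pattern, moving_left):
    """Hypothetical helper: slot mask for each lane's pattern number."""
    # Clamp so pattern numbers past 3 reuse the densest pattern (assumption).
    mask = PATTERN_SLOTS[jnp.clip(pattern, 0, PATTERN_SLOTS.shape[0] - 1)]
    # Mirror the slot order for lanes whose enemies move left (assumption).
    return jnp.where(jnp.asarray(moving_left)[..., None], mask[..., ::-1], mask)


lanes = jnp.array([0, 2, 3, 1])
left = jnp.array([False, False, False, True])
print(slots_for_pattern(lanes, left).sum(axis=1).tolist())  # [1, 2, 3, 2]
```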

get_spawn_position(moving_left: Array | ndarray | bool_ | number, slot: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Get the spawn position based on movement direction and slot number.

image_space() → Box

Returns the image space for Seaquest: an RGB image with shape (210, 160, 3).
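For reference, a blank frame in this (height, width, channel) layout can be built by broadcasting BACKGROUND_COLOR from SeaquestConstants. The uint8 dtype is an assumption for illustration; the dtype of the real rendered raster is not stated here.

```python
import jax.numpy as jnp

HEIGHT, WIDTH = 210, 160
BACKGROUND_COLOR = (0, 0, 139)  # copied from SeaquestConstants

# Every pixel gets the background color; uint8 channels are an assumption.
frame = jnp.broadcast_to(
    jnp.array(BACKGROUND_COLOR, dtype=jnp.uint8), (HEIGHT, WIDTH, 3)
)
print(frame.shape)  # (210, 160, 3)
```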

initialize_spawn_state() → SpawnState

Initialize the spawn state so that the first wave matches the original game.

is_slot_empty(pos: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Check whether a position slot is empty, i.e. equal to (0, 0, 0).
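A minimal sketch of such a check, assuming each slot is a row of three integers (for example x, y and a direction flag) and that empty slots are all zeros. slot_is_empty is a hypothetical standalone function, not the method itself, and the example rows are made up.

```python
import jax.numpy as jnp


def slot_is_empty(pos):
    """Hypothetical helper: True where every component of a slot is zero.

    Works on a single (3,) slot or on an (N, 3) array of slots.
    """
    return jnp.all(jnp.asarray(pos) == 0, axis=-1)


shark_positions = jnp.array([[0, 0, 0], [40, 71, 1], [0, 0, 0]])
empty = slot_is_empty(shark_positions)
print(empty.tolist())  # [True, False, True]

# argmax over a boolean mask returns the index of the first free slot.
first_free = jnp.argmax(empty)
```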

obs_to_flat_array(obs: SeaquestObservation) → Array

observation_space() → Dict

Returns the observation space for Seaquest. The observation contains:

- player: PlayerEntity (x, y, o, width, height, active)
- sharks: array of shape (12, 5) with x, y, width, height, active for each shark
- submarines: array of shape (12, 5) with x, y, width, height, active for each submarine
- divers: array of shape (4, 5) with x, y, width, height, active for each diver
- enemy_missiles: array of shape (4, 5) with x, y, width, height, active for each missile
- surface_submarine: EntityPosition (x, y, width, height, active)
- player_missile: EntityPosition (x, y, width, height, active)
- collected_divers: int (0-6)
- player_score: int (0-999999)
- lives: int (0-3)
- oxygen_level: int (0-255)
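To see the flat size these shapes imply, the sketch below builds placeholder arrays for each field and concatenates them. It assumes obs_to_flat_array simply ravels and concatenates every field; the dict-of-arrays layout is a simplification of the NamedTuple observation, and the real method's field order and dtype may differ.

```python
import jax
import jax.numpy as jnp

# Placeholder values with the documented shapes.
obs = {
    "player": jnp.zeros(6),               # x, y, o, width, height, active
    "sharks": jnp.zeros((12, 5)),
    "submarines": jnp.zeros((12, 5)),
    "divers": jnp.zeros((4, 5)),
    "enemy_missiles": jnp.zeros((4, 5)),
    "surface_submarine": jnp.zeros(5),
    "player_missile": jnp.zeros(5),
    "collected_divers": jnp.array(0),
    "player_score": jnp.array(0),
    "lives": jnp.array(3),
    "oxygen_level": jnp.array(64),
}


def flatten(tree):
    """Ravel every leaf to 1-D float32 and concatenate them."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.concatenate([jnp.ravel(leaf).astype(jnp.float32) for leaf in leaves])


flat = flatten(obs)
print(flat.shape)  # (180,)
```

Note that tree_leaves visits dict entries in sorted key order, whereas a NamedTuple keeps its declared field order, so the position of each value in the flat vector depends on the container type.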

player_missile_step(state: SeaquestState, curr_player_x, curr_player_y, action: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

player_step(state: SeaquestState, action: Array | ndarray | bool_ | number) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]

render(state: SeaquestState) → Array

Render the game state to a raster image.

reset(key: PRNGKey = Array([0, 42], dtype=uint32)) → Tuple[SeaquestObservation, SeaquestState]

Initialize the game state and return the initial observation together with the state.
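The default key Array([0, 42], dtype=uint32) is what jax.random.PRNGKey(42) produces with JAX's default legacy key type. Passing a different key should give a different episode; a common pattern is to split one master key into per-episode keys, each of which could then be passed as reset(key=...).

```python
import jax
import jax.numpy as jnp

# The documented default key equals the legacy PRNGKey(42).
default_key = jax.random.PRNGKey(42)
print(default_key.tolist())  # [0, 42]

# Derive one independent key per episode from a master key.
master = jax.random.PRNGKey(0)
episode_keys = jax.random.split(master, 4)  # shape (4, 2), one row per episode
```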

soft_reset_spawn_state(spawn_state: SpawnState) → SpawnState

Reset the spawn timers (spawn_timers).

spawn_divers(spawn_state: SpawnState, diver_positions: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number) → tuple[Array | ndarray | bool_ | number, SpawnState]

Spawn divers (vectorized) according to a pattern that depends on the diver collection state.

spawn_step(state, spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, diver_positions: Array | ndarray | bool_ | number, rng_key: Array) → Tuple[SpawnState, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]

Main spawn-handling function, called during each game step.

step(state: SeaquestState, action: Array | ndarray | bool_ | number) → Tuple[SeaquestObservation, SeaquestState, float, bool, SeaquestInfo]

step_diver_movement(diver_positions: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, state_player_x: Array | ndarray | bool_ | number, state_player_y: Array | ndarray | bool_ | number, state_divers_collected: Array | ndarray | bool_ | number, spawn_state: SpawnState, step_counter: Array | ndarray | bool_ | number, rng: Array) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]

Move divers according to their pattern and handle collisions. Returns updated diver positions, number of collected divers, updated spawn state, and updated RNG key.

step_enemy_movement(spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number, rng: Array) → Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]

Update enemy positions based on their patterns.

surface_sub_step(state: SeaquestState) → Array | ndarray | bool_ | number

update_enemy_spawns(spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, diver_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number, rng: Array = None) → Tuple[SpawnState, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array]

Update enemy spawns using a pattern-based system that matches the original game.

Parameters:

spawn_state – Current spawn state
shark_positions – Current shark positions
sub_positions – Current submarine positions
diver_positions – Current diver positions
step_counter – Current step counter
rng – Optional random key for direction randomization

Returns:

Tuple of updated spawn state, shark positions, sub positions, and updated RNG key

update_oxygen(state, player_x, player_y, player_missile_position)

Update oxygen levels and handle the surfacing mechanics, including detecting when the player surfaces.

class jaxatari.games.jax_seaquest.PlayerEntity(x, y, o, width, height, active)

Bases: NamedTuple

active: Array

Alias for field number 5

height: Array

Alias for field number 4

o: Array

Alias for field number 2

width: Array

Alias for field number 3

x: Array

Alias for field number 0

y: Array

Alias for field number 1

class jaxatari.games.jax_seaquest.SeaquestConstants(ASSET_CONFIG)

Bases: NamedTuple

ASSET_CONFIG: tuple

Alias for field number 0

BACKGROUND_COLOR = (0, 0, 139)
DIVER_COLOR = (66, 72, 200)
DIVER_SIZE = (8, 11)
DIVER_SPAWN_POSITIONS = Array([ 69,  93, 117, 141], dtype=int32)
ENEMY_MISSILE_Y = Array([ 73,  97, 121, 141], dtype=int32)
ENEMY_SUB_COLOR = (170, 170, 170)
ENEMY_SUB_SIZE = (8, 11)
FACE_LEFT = -1
FACE_RIGHT = 1
FIRST_WAVE_DIRS = Array([False, False, False,  True], dtype=bool)
MAX_COLLECTED_DIVERS = 6
MAX_DIVERS = 4
MAX_ENEMY_MISSILES = 4
MAX_PLAYER_TORPS = 1
MAX_SHARKS = 12
MAX_SUBS = 12
MAX_SURFACE_SUBS = 1
MISSILE_SIZE = (8, 1)
MISSILE_SPAWN_POSITIONS = Array([ 39, 126], dtype=int32)
OXYGEN_BAR_COLOR = (214, 214, 214, 255)
OXYGEN_TEXT_COLOR = (0, 0, 0)
PLAYER_BOUNDS = ((21, 134), (46, 141))
PLAYER_COLOR = (187, 187, 53)
PLAYER_SIZE = (16, 11)
PLAYER_START_X = 76
PLAYER_START_Y = 46
SCORE_COLOR = (210, 210, 64)
SHARK_DIFFICULTY_COLORS = Array([[92, 186, 92], [213, 130, 74], [170, 92, 170], [213, 92, 130], [186, 92, 92]], dtype=int32)
SHARK_SIZE = (8, 7)
SPAWN_POSITIONS_Y = Array([ 71,  95, 119, 139], dtype=int32)
SUBMARINE_Y_OFFSET = 2
X_BORDERS = (0, 160)
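
Reading PLAYER_BOUNDS as ((x_min, x_max), (y_min, y_max)) is an assumption, though it is consistent with PLAYER_START_Y = 46 equalling the lower y bound. Under that reading, keeping the submarine inside the play area is a per-axis clip; clamp_player is a hypothetical helper.

```python
import jax.numpy as jnp

PLAYER_BOUNDS = ((21, 134), (46, 141))  # copied from SeaquestConstants
PLAYER_START_X, PLAYER_START_Y = 76, 46


def clamp_player(x, y, bounds=PLAYER_BOUNDS):
    """Hypothetical helper; assumes bounds = ((x_min, x_max), (y_min, y_max))."""
    (x_min, x_max), (y_min, y_max) = bounds
    return jnp.clip(x, x_min, x_max), jnp.clip(y, y_min, y_max)


x, y = clamp_player(PLAYER_START_X - 100, PLAYER_START_Y + 200)
print(int(x), int(y))  # 21 141
```
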

class jaxatari.games.jax_seaquest.SeaquestInfo(difficulty, successful_rescues, step_counter)

Bases: NamedTuple

difficulty: Array

Alias for field number 0

step_counter: Array

Alias for field number 2

successful_rescues: Array

Alias for field number 1

class jaxatari.games.jax_seaquest.SeaquestObservation(player, sharks, submarines, divers, enemy_missiles, surface_submarine, player_missile, collected_divers, player_score, lives, oxygen_level)

Bases: NamedTuple

collected_divers: Array

Alias for field number 7

divers: Array

Alias for field number 3

enemy_missiles: Array

Alias for field number 4

lives: Array

Alias for field number 9

oxygen_level: Array

Alias for field number 10

player: PlayerEntity

Alias for field number 0

player_missile: EntityPosition

Alias for field number 6

player_score: Array

Alias for field number 8

sharks: Array

Alias for field number 1

submarines: Array

Alias for field number 2

surface_submarine: EntityPosition

Alias for field number 5

class jaxatari.games.jax_seaquest.SeaquestRenderer(consts: SeaquestConstants | None = None)

Bases: JAXGameRenderer

render(state: SeaquestState) → Array

render_object_sequentially(current_raster, pos, shape_masks, flip_offsets, anim_idx)

Helper to render a single object with a pre-calculated animation index.

class jaxatari.games.jax_seaquest.SeaquestState(player_x, player_y, player_direction, oxygen, divers_collected, score, lives, spawn_state, diver_positions, shark_positions, sub_positions, enemy_missile_positions, surface_sub_position, player_missile_position, step_counter, just_surfaced, successful_rescues, death_counter, rng_key)

Bases: NamedTuple

death_counter: Array | ndarray | bool_ | number

Alias for field number 17

diver_positions: Array | ndarray | bool_ | number

Alias for field number 8

divers_collected: Array | ndarray | bool_ | number

Alias for field number 4

enemy_missile_positions: Array | ndarray | bool_ | number

Alias for field number 11

just_surfaced: Array | ndarray | bool_ | number

Alias for field number 15

lives: Array | ndarray | bool_ | number

Alias for field number 6

oxygen: Array | ndarray | bool_ | number

Alias for field number 3

player_direction: Array | ndarray | bool_ | number

Alias for field number 2

player_missile_position: Array | ndarray | bool_ | number

Alias for field number 13

player_x: Array | ndarray | bool_ | number

Alias for field number 0

player_y: Array | ndarray | bool_ | number

Alias for field number 1

rng_key: Array

Alias for field number 18

score: Array | ndarray | bool_ | number

Alias for field number 5

shark_positions: Array | ndarray | bool_ | number

Alias for field number 9

spawn_state: SpawnState

Alias for field number 7

step_counter: Array | ndarray | bool_ | number

Alias for field number 14

sub_positions: Array | ndarray | bool_ | number

Alias for field number 10

successful_rescues: Array | ndarray | bool_ | number

Alias for field number 16

surface_sub_position: Array | ndarray | bool_ | number

Alias for field number 12

class jaxatari.games.jax_seaquest.SpawnState(difficulty, lane_dependent_pattern, to_be_spawned, survived, prev_sub, spawn_timers, diver_array, lane_directions)

Bases: NamedTuple

difficulty: Array | ndarray | bool_ | number

Alias for field number 0

diver_array: Array | ndarray | bool_ | number

Alias for field number 6

lane_dependent_pattern: Array | ndarray | bool_ | number

Alias for field number 1

lane_directions: Array | ndarray | bool_ | number

Alias for field number 7

prev_sub: Array | ndarray | bool_ | number

Alias for field number 4

spawn_timers: Array | ndarray | bool_ | number

Alias for field number 5

survived: Array | ndarray | bool_ | number

Alias for field number 3

to_be_spawned: Array | ndarray | bool_ | number

Alias for field number 2

jaxatari.games.jax_seaquest.get_shark_color_index(difficulty: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number

Determine which shark color to use based on difficulty level. Color cycle: Green -> Yellow -> Pink -> Orange -> Green -> Yellow -> Green -> Orange -> back to start

Parameters:

difficulty – Current difficulty level (0-7)

Returns:

Color index: 0=Green, 1=Yellow, 2=Pink, 3=Orange

Return type:

Array | ndarray | bool_ | number
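The color cycle above maps naturally onto a lookup table indexed by difficulty. This is a hypothetical sketch, not the library's implementation; wrapping difficulties above 7 with a modulo is an assumption drawn from "back to start".

```python
import jax.numpy as jnp

# One entry per difficulty 0-7, following the documented cycle:
# Green, Yellow, Pink, Orange, Green, Yellow, Green, Orange
# (0=Green, 1=Yellow, 2=Pink, 3=Orange).
COLOR_CYCLE = jnp.array([0, 1, 2, 3, 0, 1, 0, 3], dtype=jnp.int32)


def shark_color_index(difficulty):
    """Hypothetical helper: color index for a difficulty level."""
    # Modulo restarts the cycle for difficulties beyond 7 (assumption).
    return COLOR_CYCLE[jnp.asarray(difficulty) % COLOR_CYCLE.shape[0]]


print(shark_color_index(jnp.arange(10)).tolist())
# [0, 1, 2, 3, 0, 1, 0, 3, 0, 1]
```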