Seaquest Environment
- class jaxatari.games.jax_seaquest.CarryState(missile_pos, shark_pos, sub_pos, score)
Bases: NamedTuple
- missile_pos: Array | ndarray | bool_ | number
Alias for field number 0
- score: Array | ndarray | bool_ | number
Alias for field number 3
- shark_pos: Array | ndarray | bool_ | number
Alias for field number 1
- sub_pos: Array | ndarray | bool_ | number
Alias for field number 2
- class jaxatari.games.jax_seaquest.EntityPosition(x, y, width, height, active)
Bases: NamedTuple
- active: Array
Alias for field number 4
- height: Array
Alias for field number 3
- width: Array
Alias for field number 2
- x: Array
Alias for field number 0
- y: Array
Alias for field number 1
- class jaxatari.games.jax_seaquest.JaxSeaquest(consts: SeaquestConstants | None = None)
Bases: JaxEnvironment[SeaquestState, SeaquestObservation, SeaquestInfo, SeaquestConstants]
- ACTION_SET: Array = Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], dtype=int32)
- action_space() → Discrete
Returns the action space of the environment as a Discrete space over the actions that can be taken.
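For illustration, the snippet below draws random actions from an array mirroring ACTION_SET (the integers 0 to 17). The ACTION_NAMES tuple is an assumption: it follows the Arcade Learning Environment's full-action-set ordering, which this page does not confirm for JaxSeaquest. The helper sample_actions is hypothetical.

```python
import jax
import jax.numpy as jnp

# Mirrors the documented ACTION_SET constant: the 18 integers 0..17.
ACTION_SET = jnp.arange(18, dtype=jnp.int32)

# Assumption: ALE full-action-set naming; not confirmed by this page.
ACTION_NAMES = (
    "NOOP", "FIRE", "UP", "RIGHT", "LEFT", "DOWN",
    "UPRIGHT", "UPLEFT", "DOWNRIGHT", "DOWNLEFT",
    "UPFIRE", "RIGHTFIRE", "LEFTFIRE", "DOWNFIRE",
    "UPRIGHTFIRE", "UPLEFTFIRE", "DOWNRIGHTFIRE", "DOWNLEFTFIRE",
)

def sample_actions(key, batch_size):
    """Draw batch_size actions uniformly at random (with replacement) from ACTION_SET."""
    return jax.random.choice(key, ACTION_SET, shape=(batch_size,))

actions = sample_actions(jax.random.PRNGKey(0), 4)
print([ACTION_NAMES[int(a)] for a in actions])
```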
- calculate_kill_points(successful_rescues: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
Calculate the points awarded for killing a shark or submarine. A kill is worth 20 points plus 10 points for each successful rescue, capped at 90 points.
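The scoring rule above reduces to min(20 + 10 * successful_rescues, 90). A minimal JAX sketch of that formula (not the library's implementation):

```python
import jax.numpy as jnp

def calculate_kill_points(successful_rescues):
    """20 points base, +10 per successful rescue, capped at 90."""
    return jnp.minimum(20 + 10 * jnp.asarray(successful_rescues), 90)

# Points for 0..8 successful rescues.
print(calculate_kill_points(jnp.arange(9)))
```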
- check_collision_batch(pos1, size1, pos2_array, size2)
Check for collisions between one entity and an array of entities.
- check_collision_single(pos1, size1, pos2, size2)
Check for a collision between two single entities.
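A minimal sketch of what these two checks could look like, assuming each entity is an axis-aligned box given by a top-left position (x, y) and a size (width, height), and that boxes which only share an edge do not collide. The batch version maps the single check over the rows of pos2_array with jax.vmap; the library's actual implementation may differ.

```python
import jax
import jax.numpy as jnp

def check_collision_single(pos1, size1, pos2, size2):
    """Axis-aligned bounding-box overlap; boxes that only touch at an edge do not collide."""
    overlap_x = (pos1[0] < pos2[0] + size2[0]) & (pos2[0] < pos1[0] + size1[0])
    overlap_y = (pos1[1] < pos2[1] + size2[1]) & (pos2[1] < pos1[1] + size1[1])
    return overlap_x & overlap_y

def check_collision_batch(pos1, size1, pos2_array, size2):
    """Test one box against every row of pos2_array (shape (N, 2)); returns N booleans."""
    return jax.vmap(lambda pos2: check_collision_single(pos1, size1, pos2, size2))(pos2_array)

player_pos = jnp.array([76, 46])   # PLAYER_START_X, PLAYER_START_Y
player_size = jnp.array([16, 11])  # PLAYER_SIZE, assumed (width, height)
shark_size = jnp.array([8, 7])     # SHARK_SIZE, assumed (width, height)
sharks = jnp.array([[80, 50], [0, 0], [92, 46]])

hits = check_collision_batch(player_pos, player_size, sharks, shark_size)
print(hits)
```

The third shark starts exactly where the player's box ends (76 + 16 = 92), so under the strict inequalities it does not count as a hit.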
- check_missile_collisions(missile_pos: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, score: Array | ndarray | bool_ | number, successful_rescues: Array | ndarray | bool_ | number, spawn_state: SpawnState, rng_key: Array) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]
Check for collisions between player missile and enemies using a vectorized approach.
- check_player_collision(player_x, player_y, submarine_list, shark_list, surface_sub_pos, enemy_projectile_list, score, successful_rescues) → Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]
- enemy_missiles_step(curr_sub_positions, curr_enemy_missile_positions, step_counter, difficulty) → Array | ndarray | bool_ | number
- flatten_entity_position(entity: EntityPosition) → Array
- flatten_player_entity(entity: PlayerEntity) → Array
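As an assumption based on the five columns (x, y, width, height, active) listed for sharks, submarines, divers, and missiles in observation_space(), flatten_entity_position may stack the fields of an EntityPosition in declaration order. The sketch below shows that idea; the local EntityPosition is a stand-in with the documented field names, and the int32 cast is an assumption.

```python
from typing import NamedTuple
import jax.numpy as jnp

class EntityPosition(NamedTuple):
    # Field order as documented: x, y, width, height, active.
    x: jnp.ndarray
    y: jnp.ndarray
    width: jnp.ndarray
    height: jnp.ndarray
    active: jnp.ndarray

def flatten_entity_position(entity: EntityPosition) -> jnp.ndarray:
    """Stack the fields along a new last axis, in field order (assumed layout)."""
    return jnp.stack([jnp.asarray(field, dtype=jnp.int32) for field in entity], axis=-1)

# One entity: scalar fields give shape (5,).
missile = EntityPosition(x=40, y=60, width=8, height=1, active=1)
flat_missile = flatten_entity_position(missile)

# Twelve entities: fields of shape (12,) give shape (12, 5), like the sharks array.
sharks = EntityPosition(
    x=jnp.arange(12), y=jnp.full(12, 71), width=jnp.full(12, 8),
    height=jnp.full(12, 7), active=jnp.zeros(12),
)
flat_sharks = flatten_entity_position(sharks)
print(flat_missile, flat_sharks.shape)
```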
- get_pattern_for_difficulty(current_pattern: Array | ndarray | bool_ | number, moving_left: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
Returns the spawn pattern based on the lane's current wave/pattern number.
Pattern meanings:
0: Single enemy (initial pattern)
1: Two adjacent enemies
2: Two enemies with a gap
3: Three enemies in a row
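One way to picture these patterns is as an occupancy mask over the spawn slots of a lane. The encoding below is hypothetical: it assumes three slots per lane (inferred from MAX_SHARKS = 12 spread over four lanes) and ignores the moving_left argument, which the real function also receives.

```python
import jax.numpy as jnp

# Hypothetical encoding: one row per pattern number, one column per spawn slot.
PATTERN_SLOTS = jnp.array([
    [True,  False, False],  # 0: single enemy
    [True,  True,  False],  # 1: two adjacent enemies
    [True,  False, True ],  # 2: two enemies with a gap
    [True,  True,  True ],  # 3: three enemies in a row
])

def slots_for_pattern(pattern):
    """Boolean occupancy of the three slots of a lane for the given pattern number(s)."""
    return PATTERN_SLOTS[jnp.clip(pattern, 0, 3)]

lane_patterns = jnp.array([0, 2, 3, 1])  # one pattern number per lane
print(slots_for_pattern(lane_patterns).astype(jnp.int32))
```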
- get_spawn_position(moving_left: Array | ndarray | bool_ | number, slot: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
Get the spawn position based on the movement direction and slot number.
- image_space() → Box
Returns the image space for Seaquest: an RGB image with shape (210, 160, 3).
- initialize_spawn_state() → SpawnState
Initialize spawn state with first wave matching original game.
- is_slot_empty(pos: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
Check whether a position slot is empty, i.e. equal to (0, 0, 0).
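Assuming an empty slot is one whose three components are all zero, the check can be written so it works on a single slot or on a whole array of slots at once. This is a sketch, not the library code:

```python
import jax.numpy as jnp

def is_slot_empty(pos):
    """True where every component along the last axis is zero."""
    return jnp.all(pos == 0, axis=-1)

slots = jnp.array([[0, 0, 0], [42, 71, 1], [0, 0, 0]])
print(is_slot_empty(slots))
```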
- obs_to_flat_array(obs: SeaquestObservation) → Array
- observation_space() → Dict
Returns the observation space for Seaquest. The observation contains:
- player: PlayerEntity (x, y, o, width, height, active)
- sharks: array of shape (12, 5) with x, y, width, height, active for each shark
- submarines: array of shape (12, 5) with x, y, width, height, active for each submarine
- divers: array of shape (4, 5) with x, y, width, height, active for each diver
- enemy_missiles: array of shape (4, 5) with x, y, width, height, active for each missile
- surface_submarine: EntityPosition (x, y, width, height, active)
- player_missile: EntityPosition (x, y, width, height, active)
- collected_divers: int (0-6)
- player_score: int (0-999999)
- lives: int (0-3)
- oxygen_level: int (0-255)
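If obs_to_flat_array concatenates every field listed above without padding or extra entries (an assumption; its layout is not documented here), the flat vector length follows from the listed shapes, counting each scalar as one entry:

```python
import math

# (name, shape) pairs taken from the observation_space() description; scalars are shape ().
OBS_FIELDS = [
    ("player", (6,)),            # x, y, o, width, height, active
    ("sharks", (12, 5)),
    ("submarines", (12, 5)),
    ("divers", (4, 5)),
    ("enemy_missiles", (4, 5)),
    ("surface_submarine", (5,)),
    ("player_missile", (5,)),
    ("collected_divers", ()),
    ("player_score", ()),
    ("lives", ()),
    ("oxygen_level", ()),
]

# math.prod(()) == 1, so each scalar contributes one entry.
flat_size = sum(math.prod(shape) for _, shape in OBS_FIELDS)
print(flat_size)
```

Comparing this number with obs_to_flat_array(obs).shape on a real observation is a quick way to confirm or reject the assumption.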
- player_missile_step(state: SeaquestState, curr_player_x, curr_player_y, action: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
- player_step(state: SeaquestState, action: Array | ndarray | bool_ | number) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]
- render(state: SeaquestState) → Array
Render the game state to a raster image.
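To illustrate the general technique of rasterizing under jax.jit, the sketch below fills a frame with BACKGROUND_COLOR and paints the player as a solid rectangle using a coordinate mask, so positions can be traced values. It is not how SeaquestRenderer draws sprites; the helper draw_rect is hypothetical and the player sprite is approximated by its bounding box.

```python
import jax
import jax.numpy as jnp

HEIGHT, WIDTH = 210, 160  # image_space() shape is (210, 160, 3)
BACKGROUND_COLOR = jnp.array([0, 0, 139], dtype=jnp.uint8)
PLAYER_COLOR = jnp.array([187, 187, 53], dtype=jnp.uint8)

@jax.jit
def draw_rect(frame, x, y, w, h, color):
    """Paint a w-by-h rectangle with top-left corner (x, y); all four may be traced."""
    rows = jnp.arange(HEIGHT)[:, None]
    cols = jnp.arange(WIDTH)[None, :]
    mask = (rows >= y) & (rows < y + h) & (cols >= x) & (cols < x + w)
    # Broadcast the (H, W, 1) mask against the RGB channels.
    return jnp.where(mask[..., None], color, frame)

frame = jnp.broadcast_to(BACKGROUND_COLOR, (HEIGHT, WIDTH, 3))
# PLAYER_START_X/Y and PLAYER_SIZE from SeaquestConstants.
frame = draw_rect(frame, 76, 46, 16, 11, PLAYER_COLOR)
```

The mask approach avoids slicing with traced indices, which jax.jit does not allow for dynamic slice bounds.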
- reset(key: PRNGKey = Array([0, 42], dtype=uint32)) → Tuple[SeaquestObservation, SeaquestState]
Initialize the game state and return the initial observation together with the state.
- soft_reset_spawn_state(spawn_state: SpawnState) → SpawnState
Reset the spawn timers (spawn_timers) of the spawn state.
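All state containers on this page are NamedTuples, so an immutable update is a _replace call. The sketch below assumes that a soft reset sets spawn_timers to zero and leaves the other fields alone; the field names come from the SpawnState entry, while the shapes and values are illustrative only.

```python
from typing import NamedTuple
import jax.numpy as jnp

class SpawnState(NamedTuple):
    # Field names as documented for SpawnState; shapes below are illustrative.
    difficulty: jnp.ndarray
    lane_dependent_pattern: jnp.ndarray
    to_be_spawned: jnp.ndarray
    survived: jnp.ndarray
    prev_sub: jnp.ndarray
    spawn_timers: jnp.ndarray
    diver_array: jnp.ndarray
    lane_directions: jnp.ndarray

def soft_reset_spawn_state(spawn_state: SpawnState) -> SpawnState:
    """Return a copy with spawn_timers zeroed (assumed reset value); other fields are kept."""
    return spawn_state._replace(spawn_timers=jnp.zeros_like(spawn_state.spawn_timers))

state = SpawnState(
    difficulty=jnp.array(2),
    lane_dependent_pattern=jnp.array([0, 1, 2, 3]),
    to_be_spawned=jnp.zeros(12, dtype=jnp.int32),
    survived=jnp.zeros(12, dtype=jnp.int32),
    prev_sub=jnp.array([1, 0, 0, 0]),
    spawn_timers=jnp.array([5, 17, 3, 9]),
    diver_array=jnp.array([1, 0, 0, 1]),
    lane_directions=jnp.array([False, False, False, True]),
)
reset = soft_reset_spawn_state(state)
```

Because _replace returns a new tuple, the original state is untouched, which keeps the update safe inside jitted or scanned code.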
- spawn_divers(spawn_state: SpawnState, diver_positions: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number) → tuple[Array | ndarray | bool_ | number, SpawnState]
Vectorized function that spawns divers according to a pattern that depends on the collection state.
- spawn_step(state, spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, diver_positions: Array | ndarray | bool_ | number, rng_key: Array) → Tuple[SpawnState, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number]
Main spawn-handling function, called during the game step.
- step(state: SeaquestState, action: Array | ndarray | bool_ | number) → Tuple[SeaquestObservation, SeaquestState, float, bool, SeaquestInfo]
- step_diver_movement(diver_positions: Array | ndarray | bool_ | number, shark_positions: Array | ndarray | bool_ | number, state_player_x: Array | ndarray | bool_ | number, state_player_y: Array | ndarray | bool_ | number, state_divers_collected: Array | ndarray | bool_ | number, spawn_state: SpawnState, step_counter: Array | ndarray | bool_ | number, rng: Array) → tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]
Move divers according to their pattern and handle collisions. Returns updated diver positions, number of collected divers, updated spawn state, and updated RNG key.
- step_enemy_movement(spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number, rng: Array) → Tuple[Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, SpawnState, Array]
Update enemy positions based on their patterns.
- surface_sub_step(state: SeaquestState) → Array | ndarray | bool_ | number
- update_enemy_spawns(spawn_state: SpawnState, shark_positions: Array | ndarray | bool_ | number, sub_positions: Array | ndarray | bool_ | number, diver_positions: Array | ndarray | bool_ | number, step_counter: Array | ndarray | bool_ | number, rng: Array = None) → Tuple[SpawnState, Array | ndarray | bool_ | number, Array | ndarray | bool_ | number, Array]
Update enemy spawns using a pattern-based system that matches the original game.
- Parameters:
spawn_state – Current spawn state
shark_positions – Current shark positions
sub_positions – Current submarine positions
diver_positions – Current diver positions
step_counter – Current step counter
rng – Optional random key for direction randomization
- Returns:
Tuple of updated spawn state, shark positions, sub positions, and updated RNG key
- class jaxatari.games.jax_seaquest.PlayerEntity(x, y, o, width, height, active)
Bases: NamedTuple
- active: Array
Alias for field number 5
- height: Array
Alias for field number 4
- o: Array
Alias for field number 2
- width: Array
Alias for field number 3
- x: Array
Alias for field number 0
- y: Array
Alias for field number 1
- class jaxatari.games.jax_seaquest.SeaquestConstants(ASSET_CONFIG)
Bases: NamedTuple
- ASSET_CONFIG: tuple
Alias for field number 0
- BACKGROUND_COLOR = (0, 0, 139)
- DIVER_COLOR = (66, 72, 200)
- DIVER_SIZE = (8, 11)
- DIVER_SPAWN_POSITIONS = Array([ 69, 93, 117, 141], dtype=int32)
- ENEMY_MISSILE_Y = Array([ 73, 97, 121, 141], dtype=int32)
- ENEMY_SUB_COLOR = (170, 170, 170)
- ENEMY_SUB_SIZE = (8, 11)
- FACE_LEFT = -1
- FACE_RIGHT = 1
- FIRST_WAVE_DIRS = Array([False, False, False, True], dtype=bool)
- MAX_COLLECTED_DIVERS = 6
- MAX_DIVERS = 4
- MAX_ENEMY_MISSILES = 4
- MAX_PLAYER_TORPS = 1
- MAX_SHARKS = 12
- MAX_SUBS = 12
- MAX_SURFACE_SUBS = 1
- MISSILE_SIZE = (8, 1)
- MISSILE_SPAWN_POSITIONS = Array([ 39, 126], dtype=int32)
- OXYGEN_BAR_COLOR = (214, 214, 214, 255)
- OXYGEN_TEXT_COLOR = (0, 0, 0)
- PLAYER_BOUNDS = ((21, 134), (46, 141))
- PLAYER_COLOR = (187, 187, 53)
- PLAYER_SIZE = (16, 11)
- PLAYER_START_X = 76
- PLAYER_START_Y = 46
- SCORE_COLOR = (210, 210, 64)
- SHARK_DIFFICULTY_COLORS = Array([[ 92, 186, 92], [213, 130, 74], [170, 92, 170], [213, 92, 130], [186, 92, 92]], dtype=int32)
- SHARK_SIZE = (8, 7)
- SPAWN_POSITIONS_Y = Array([ 71, 95, 119, 139], dtype=int32)
- SUBMARINE_Y_OFFSET = 2
- X_BORDERS = (0, 160)
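Reading PLAYER_BOUNDS as ((x_min, x_max), (y_min, y_max)) is an assumption, though it agrees with PLAYER_START_Y = 46 being the lowest allowed y. Under that reading, keeping the player inside the play area is a pair of clips; clamp_player is a hypothetical helper:

```python
import jax.numpy as jnp

PLAYER_BOUNDS = ((21, 134), (46, 141))  # assumed ((x_min, x_max), (y_min, y_max))

def clamp_player(x, y):
    """Keep the player's position inside PLAYER_BOUNDS (inclusive on both ends)."""
    (x_min, x_max), (y_min, y_max) = PLAYER_BOUNDS
    return jnp.clip(x, x_min, x_max), jnp.clip(y, y_min, y_max)

x, y = clamp_player(jnp.array(150), jnp.array(40))
print(int(x), int(y))  # 134 46
```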
- class jaxatari.games.jax_seaquest.SeaquestInfo(difficulty, successful_rescues, step_counter)
Bases: NamedTuple
- difficulty: Array
Alias for field number 0
- step_counter: Array
Alias for field number 2
- successful_rescues: Array
Alias for field number 1
- class jaxatari.games.jax_seaquest.SeaquestObservation(player, sharks, submarines, divers, enemy_missiles, surface_submarine, player_missile, collected_divers, player_score, lives, oxygen_level)
Bases: NamedTuple
- collected_divers: Array
Alias for field number 7
- divers: Array
Alias for field number 3
- enemy_missiles: Array
Alias for field number 4
- lives: Array
Alias for field number 9
- oxygen_level: Array
Alias for field number 10
- player: PlayerEntity
Alias for field number 0
- player_missile: EntityPosition
Alias for field number 6
- player_score: Array
Alias for field number 8
- sharks: Array
Alias for field number 1
- submarines: Array
Alias for field number 2
- surface_submarine: EntityPosition
Alias for field number 5
- class jaxatari.games.jax_seaquest.SeaquestRenderer(consts: SeaquestConstants | None = None)
Bases: JAXGameRenderer
- render(state: SeaquestState) → Array
- class jaxatari.games.jax_seaquest.SeaquestState(player_x, player_y, player_direction, oxygen, divers_collected, score, lives, spawn_state, diver_positions, shark_positions, sub_positions, enemy_missile_positions, surface_sub_position, player_missile_position, step_counter, just_surfaced, successful_rescues, death_counter, rng_key)
Bases: NamedTuple
- death_counter: Array | ndarray | bool_ | number
Alias for field number 17
- diver_positions: Array | ndarray | bool_ | number
Alias for field number 8
- divers_collected: Array | ndarray | bool_ | number
Alias for field number 4
- enemy_missile_positions: Array | ndarray | bool_ | number
Alias for field number 11
- just_surfaced: Array | ndarray | bool_ | number
Alias for field number 15
- lives: Array | ndarray | bool_ | number
Alias for field number 6
- oxygen: Array | ndarray | bool_ | number
Alias for field number 3
- player_direction: Array | ndarray | bool_ | number
Alias for field number 2
- player_missile_position: Array | ndarray | bool_ | number
Alias for field number 13
- player_x: Array | ndarray | bool_ | number
Alias for field number 0
- player_y: Array | ndarray | bool_ | number
Alias for field number 1
- rng_key: Array
Alias for field number 18
- score: Array | ndarray | bool_ | number
Alias for field number 5
- shark_positions: Array | ndarray | bool_ | number
Alias for field number 9
- spawn_state: SpawnState
Alias for field number 7
- step_counter: Array | ndarray | bool_ | number
Alias for field number 14
- sub_positions: Array | ndarray | bool_ | number
Alias for field number 10
- successful_rescues: Array | ndarray | bool_ | number
Alias for field number 16
- surface_sub_position: Array | ndarray | bool_ | number
Alias for field number 12
- class jaxatari.games.jax_seaquest.SpawnState(difficulty, lane_dependent_pattern, to_be_spawned, survived, prev_sub, spawn_timers, diver_array, lane_directions)
Bases: NamedTuple
- difficulty: Array | ndarray | bool_ | number
Alias for field number 0
- diver_array: Array | ndarray | bool_ | number
Alias for field number 6
- lane_dependent_pattern: Array | ndarray | bool_ | number
Alias for field number 1
- lane_directions: Array | ndarray | bool_ | number
Alias for field number 7
- prev_sub: Array | ndarray | bool_ | number
Alias for field number 4
- spawn_timers: Array | ndarray | bool_ | number
Alias for field number 5
- survived: Array | ndarray | bool_ | number
Alias for field number 3
- to_be_spawned: Array | ndarray | bool_ | number
Alias for field number 2
- jaxatari.games.jax_seaquest.get_shark_color_index(difficulty: Array | ndarray | bool_ | number) → Array | ndarray | bool_ | number
Determine which shark color to use based on the difficulty level. Color cycle: Green -> Yellow -> Pink -> Orange -> Green -> Yellow -> Green -> Orange -> back to start.
- Parameters:
difficulty – Current difficulty level (0-7)
- Returns:
Color index: 0=Green, 1=Yellow, 2=Pink, 3=Orange
- Return type:
Array | ndarray | bool_ | number
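The cycle above lists eight colors for difficulty levels 0 to 7. A lookup-table sketch of that mapping follows; wrapping with modulo 8 for higher levels is an assumption based on the cycle returning to its start, and how these indices map onto the rows of SHARK_DIFFICULTY_COLORS is not stated on this page.

```python
import jax.numpy as jnp

# Indices for difficulty 0..7, read off the documented cycle
# (0=Green, 1=Yellow, 2=Pink, 3=Orange).
COLOR_CYCLE = jnp.array([0, 1, 2, 3, 0, 1, 0, 3], dtype=jnp.int32)

def shark_color_index(difficulty):
    """Map difficulty level(s) to a color index; wraps after level 7 (assumed)."""
    return COLOR_CYCLE[jnp.asarray(difficulty) % COLOR_CYCLE.shape[0]]

print(shark_color_index(jnp.arange(10)))
```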