Source code for jaxatari.games.jax_kangaroo

import os
from functools import partial
from typing import List, NamedTuple, Tuple, Dict, Any, Optional
import jax
import jax.numpy as jnp
import chex
from jax import Array

import jaxatari.spaces as spaces
from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action
from jaxatari.renderers import JAXGameRenderer
import jaxatari.rendering.jax_rendering_utils as render_utils
from jaxatari.games.kangaroo_levels import (
    LevelConstants,
    Kangaroo_Level_1,
    Kangaroo_Level_2,
    Kangaroo_Level_3,
)


[docs] def get_default_asset_config() -> tuple: # 1. Define the game-specific asset manifest in a clear, declarative way. asset_config = [ {'name': 'background', 'type': 'background', 'file': 'background.npy'}, { 'name': 'ape', 'type': 'group', 'files': ['ape_standing.npy', 'ape_climb_left.npy', 'ape_moving.npy', 'throwing_ape.npy', 'ape_climb_right.npy'] }, { 'name': 'kangaroo', 'type': 'group', 'files': ['kangaroo.npy', 'kangaroo_dead.npy', 'kangaroo_climb.npy', 'kangaroo_ducking.npy', 'kangaroo_jump.npy', 'kangaroo_boxing.npy', 'kangaroo_walk.npy', 'kangaroo_jump_high.npy'] }, { 'name': 'bell', 'type': 'group', 'files': ['bell.npy', 'ringing_bell.npy'] }, { 'name': 'fruit', 'type': 'group', 'files': ['strawberry.npy', 'tomato.npy', 'cherry.npy', 'pineapple.npy'] }, { 'name': 'child', 'type': 'group', 'files': ['child.npy', 'child_jump.npy'] }, {'name': 'coconut', 'type': 'single', 'file': 'coconut.npy'}, {'name': 'falling_coconut', 'type': 'single', 'file': 'falling_coconut.npy'}, {'name': 'lives', 'type': 'single', 'file': 'kangaroo_lives.npy'}, { 'name': 'score_digits', 'type': 'digits', 'pattern': 'score_{}.npy' }, { 'name': 'time_digits', 'type': 'digits', 'pattern': 'time_{}.npy' } ] return asset_config
[docs] class KangarooConstants(NamedTuple): RESET: int = 18 RENDER_SCALE_FACTOR: int = 4 SCREEN_WIDTH: int = 160 SCREEN_HEIGHT: int = 210 PLAYER_WIDTH: int = 8 PLAYER_HEIGHT: int = 24 ENEMY_WIDTH: int = 8 ENEMY_HEIGHT: int = 24 FRUIT_WIDTH: int = 8 FRUIT_HEIGHT: int = 12 MAX_PLATFORMS: int = 10 BELL_WIDTH: int = 6 BELL_HEIGHT: int = 11 CHILD_WIDTH: int = 8 CHILD_HEIGHT: int = 15 MONKEY_WIDTH: int = 6 MONKEY_HEIGHT: int = 15 MONKEY_COLOR: Tuple[int, int, int] = (227, 159, 89) PLAYER_COLOR: Tuple[int, int, int] = (223, 183, 85) ENEMY_COLOR: Tuple[int, int, int] = (227, 151, 89) FRUIT_COLOR_STATE_1: Tuple[int, int, int] = (214, 92, 92) FRUIT_COLOR_STATE_2: Tuple[int, int, int] = (230, 250, 92) FRUIT_COLOR_STATE_3: Tuple[int, int, int] = (255, 92, 250) FRUIT_COLOR_STATE_4: Tuple[int, int, int] = (0, 92, 250) FRUIT_COLOR: list = [ (214, 92, 92), (230, 250, 92), (255, 92, 250), (0, 92, 250), ] COCONUT_COLOR: Tuple[int, int, int] = (162, 98, 33) PLATFORM_COLOR: Tuple[int, int, int] = (162, 98, 33) LADDER_COLOR: Tuple[int, int, int] = (129, 78, 26) BELL_COLOR: Tuple[int, int, int] = (210, 164, 74) PLAYER_START_X: int = 23 PLAYER_START_Y: int = 148 MOVEMENT_SPEED: int = 1 LEFT_CLIP: int = 16 RIGHT_CLIP: int = 144 FALLING_COCONUT_WIDTH: int = 2 FALLING_COCONUT_HEIGHT: int = 3 THROWN_COCONUT_WIDTH: int = 2 THROWN_COCONUT_HEIGHT: int = 3 LADDER_HEIGHT: chex.Array = jnp.array(35) LADDER_WIDTH: chex.Array = jnp.array(8) P_HEIGHT: chex.Array = jnp.array(4) LEVEL_1: LevelConstants = Kangaroo_Level_1 LEVEL_2: LevelConstants = Kangaroo_Level_2 LEVEL_3: LevelConstants = Kangaroo_Level_3 # sprites to enable asset overrides ASSET_CONFIG: tuple = get_default_asset_config()
# -------- Entity Classes --------
[docs] class Entity(NamedTuple): x: chex.Array y: chex.Array w: chex.Array h: chex.Array
[docs] class PlayerState(NamedTuple): # Player position x: chex.Array y: chex.Array vel_x: chex.Array orientation: chex.Array height: chex.Array # crouching is_crouching: chex.Array # jumping is_jumping: chex.Array jump_base_y: chex.Array jump_counter: chex.Array jump_orientation: chex.Array landing_base_y: chex.Array # climbing is_climbing: chex.Array climb_base_y: chex.Array climb_counter: chex.Array cooldown_counter: chex.Array # other is_crashing: chex.Array chrash_timer: chex.Array punch_left: chex.Array punch_right: chex.Array last_stood_on_platform_y: chex.Array walk_animation: chex.Array punch_counter: chex.Array # New field to track consecutive punches needs_release: chex.Array # New field to track if spacebar needs to be released
[docs] class LevelState(NamedTuple): """All level related state variables.""" timer: chex.Array platform_positions: chex.Array platform_sizes: chex.Array ladder_positions: chex.Array ladder_sizes: chex.Array fruit_positions: chex.Array fruit_actives: chex.Array fruit_stages: chex.Array bell_position: chex.Array bell_timer: chex.Array child_position: chex.Array child_velocity: chex.Array child_timer: chex.Array falling_coco_position: chex.Array falling_coco_dropping: chex.Array falling_coco_counter: chex.Array falling_coco_skip_update: chex.Array step_counter: chex.Array monkey_states: chex.Array """ - 0: non-existent - 1: moving down - 2: moving left - 3: throwing - 4: moving right - 5: moving up """ monkey_positions: chex.Array """2D array: [monkey_index, [x, y]]""" monkey_throw_timers: chex.Array spawn_protection: chex.Array coco_positions: chex.Array coco_states: chex.Array """ - 0: non existent - 1: charging - 2: throwing """ spawn_position: chex.Array """ - 0: foot - 1: head """ bell_animation: chex.Array
[docs] class KangarooState(NamedTuple): player: PlayerState level: LevelState score: chex.Array current_level: chex.Array level_finished: chex.Array levelup_timer: chex.Array reset_coords: chex.Array levelup: chex.Array lives: chex.Array
[docs] class KangarooObservation(NamedTuple): player_x: chex.Array player_y: chex.Array player_o: chex.Array platform_positions: chex.Array ladder_positions: chex.Array fruit_positions: chex.Array bell_position: chex.Array child_position: chex.Array falling_coco_position: chex.Array monkey_positions: chex.Array coco_positions: chex.Array
[docs] class KangarooInfo(NamedTuple): score: chex.Array level: chex.Array
[docs] class JaxKangaroo(JaxEnvironment[KangarooState, KangarooObservation, KangarooInfo, KangarooConstants]): # Minimal ALE action set (from scripts/action_space_helper.py) ACTION_SET: jnp.ndarray = jnp.array( [ Action.NOOP, Action.FIRE, Action.UP, Action.RIGHT, Action.LEFT, Action.DOWN, Action.UPRIGHT, Action.UPLEFT, Action.DOWNRIGHT, Action.DOWNLEFT, Action.UPFIRE, Action.RIGHTFIRE, Action.LEFTFIRE, Action.DOWNFIRE, Action.UPRIGHTFIRE, Action.UPLEFTFIRE, Action.DOWNRIGHTFIRE, Action.DOWNLEFTFIRE, ], dtype=jnp.int32, ) def __init__(self, consts: KangarooConstants = None): super().__init__(consts) self.consts = consts or KangarooConstants() self.obs_size = 111 self.renderer = KangarooRenderer(self.consts) @partial(jax.jit, static_argnums=(0,)) def _get_valid_platforms(self, level_constants: LevelConstants) -> chex.Array: return level_constants.platform_positions[:, 0] != -1 @partial(jax.jit, static_argnums=(0, 2), donate_argnums=(1,)) def _get_platforms_below_player(self, state: KangarooState, y_offset=0) -> chex.Array: player_x = state.player.x player_y = state.player.y + y_offset player_bottom_y = player_y + state.player.height level_constants: LevelConstants = self._get_level_constants(state.current_level) platform_positions = level_constants.platform_positions platform_sizes = level_constants.platform_sizes platform_x = platform_positions[:, 0] platform_y = platform_positions[:, 1] platform_width = platform_sizes[:, 0] player_is_within_platform_x = jnp.logical_and( (player_x + self.consts.PLAYER_WIDTH) >= platform_x, player_x <= (platform_x + platform_width), ) platform_is_below_player = player_bottom_y <= platform_y diff_to_platforms = jnp.where( platform_is_below_player, platform_y - player_bottom_y, 1000 ) valid_platforms = self._get_valid_platforms(level_constants) candidate_platforms = ( player_is_within_platform_x & platform_is_below_player & valid_platforms ) masked_diffs = jnp.where(candidate_platforms, diff_to_platforms, 1000) closest_platform_idx = jnp.argmin(masked_diffs) min_diff = masked_diffs[closest_platform_idx] result = jnp.zeros_like(platform_x, dtype=bool) return result.at[closest_platform_idx].set(min_diff < 1000) @partial(jax.jit, static_argnums=(0,)) def _entities_collide_with_threshold( self, e1_x: chex.Array, e1_y: chex.Array, e1_w: chex.Array, e1_h: chex.Array, e2_x: chex.Array, e2_y: chex.Array, e2_w: chex.Array, e2_h: chex.Array, threshold: chex.Array, ) -> chex.Array: """Returns True if rectangles overlap by at least threshold fraction. This only Checks for overlap in the x dimension. e1_x, e1_y, e1_w, e1_h: Entity 1 position and size e2_x, e2_y, e2_w, e2_h: Entity 2 position and size threshold: Minimum fraction of overlap required (0-1) Returns: bool: True if entities overlap by at least threshold fraction, False otherwise """ overlap_start_x = jnp.maximum(e1_x, e2_x) overlap_end_x = jnp.minimum(e1_x + e1_w, e2_x + e2_w) overlap_start_y = jnp.maximum(e1_y, e2_y) overlap_end_y = jnp.minimum(e1_y + e1_h, e2_y + e2_h) # Calculate dimensions of overlap region overlap_width = overlap_end_x - overlap_start_x overlap_height = overlap_end_y - overlap_start_y smallest_entity_width = jnp.minimum(e1_w, e2_w) # Calculate minimum required overlap area based on threshold min_required_overlap = smallest_entity_width * threshold meets_threshold = overlap_width >= min_required_overlap return jnp.where((overlap_width < 0) | (overlap_height < 0), False, meets_threshold) @partial(jax.jit, static_argnums=(0,)) def _entities_collide(self, e1_x: chex.Array, e1_y: chex.Array, e1_w: chex.Array, e1_h: chex.Array, e2_x: chex.Array, e2_y: chex.Array, e2_w: chex.Array, e2_h: chex.Array) -> chex.Array: return self._entities_collide_with_threshold( e1_x, e1_y, e1_w, e1_h, e2_x, e2_y, e2_w, e2_h, 0 ) @partial(jax.jit, static_argnums=(0, 2, 3), donate_argnums=(1,)) def _player_is_above_ladder(self, state: KangarooState, threshold: float = 0.3, virtual_hitbox_height: float = 12.0) -> chex.Array: level_constants: LevelConstants = self._get_level_constants(state.current_level) ladder_x = level_constants.ladder_positions[:, 0] ladder_y = level_constants.ladder_positions[:, 1] ladder_w = level_constants.ladder_sizes[:, 0] ladder_h = level_constants.ladder_sizes[:, 1] return jax.vmap( self._entities_collide_with_threshold, in_axes=(None, None, None, None, 0, 0, 0, 0, None), )( state.player.x, state.player.y + state.player.height, self.consts.PLAYER_WIDTH, virtual_hitbox_height, ladder_x, ladder_y, ladder_w, ladder_h, threshold, ) @partial(jax.jit, static_argnums=(0, 2), donate_argnums=(1,)) def _check_ladder_collisions(self, state: KangarooState, threshold: float = 0.3) -> chex.Array: level_constants: LevelConstants = self._get_level_constants(state.current_level) ladder_x = level_constants.ladder_positions[:, 0] ladder_y = level_constants.ladder_positions[:, 1] ladder_w = level_constants.ladder_sizes[:, 0] ladder_h = level_constants.ladder_sizes[:, 1] return jax.vmap( self._entities_collide_with_threshold, in_axes=(None, None, None, None, 0, 0, 0, 0, None), )( state.player.x, state.player.y + 16, self.consts.PLAYER_WIDTH, state.player.height - 16, ladder_x, ladder_y, ladder_w, ladder_h, threshold, ) @partial(jax.jit, static_argnums=(0, 4), donate_argnums=(1,)) def _player_is_on_ladder(self, state: KangarooState, ladder_pos: chex.Array, ladder_size: chex.Array, threshold: float = 0.3) -> chex.Array: return self._entities_collide_with_threshold( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height, ladder_pos[0], ladder_pos[1], ladder_size[0], ladder_size[1], threshold, ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _player_jump_controller(self, state: KangarooState, jump_pressed: chex.Array, ladder_intersect: chex.Array): player_y = state.player.y jump_counter = state.player.jump_counter is_jumping = state.player.is_jumping # If a jump is initiated from a crouch, we must calculate the jump_base_y # from the position the kangaroo WOULD BE IN after standing up. is_crouch_jumping = state.player.is_crouching & jump_pressed crouch_height_adjustment = self.consts.PLAYER_HEIGHT - 16 # Crouch height is 16 player_y_for_jump = jnp.where(is_crouch_jumping, player_y - crouch_height_adjustment, player_y) cooldown_condition = state.player.cooldown_counter > 0 jump_start = ( jump_pressed & ~is_jumping & ~ladder_intersect & ~cooldown_condition & ((player_y + self.consts.PLAYER_HEIGHT) > 28) ) jump_counter = jnp.where(jump_start, 0, jump_counter) jump_orientation = jnp.where( jump_start, state.player.orientation, state.player.jump_orientation ) jump_base_y = jnp.where(jump_start, player_y_for_jump, state.player.jump_base_y) new_landing_base_y = jump_base_y platform_y_below_player = self._get_y_of_platform_below_player(state) # find a new potential landing_base if player is above a higher platform new_landing_base_y = jnp.where( is_jumping & ((platform_y_below_player - self.consts.PLAYER_HEIGHT) == (jump_base_y - 8)) & ~jump_start, platform_y_below_player - self.consts.PLAYER_HEIGHT, new_landing_base_y, ) # --- Allow jumping down: if the player reaches a platform located exactly # --- 8 pixels below the jump base, treat it as a valid landing base too. new_landing_base_y = jnp.where( is_jumping & ((platform_y_below_player - self.consts.PLAYER_HEIGHT) == (jump_base_y + 8)) & ~jump_start, platform_y_below_player - self.consts.PLAYER_HEIGHT, new_landing_base_y, ) is_jumping = is_jumping | jump_start jump_counter = jnp.where(is_jumping, jump_counter + 1, jump_counter) # Calculate vertical offset based on jump phase def offset_for(count): conditions = [ (count > 0) & (count <= 8), (count > 8) & (count < 16), (count >= 16) & (count <= 24), (count > 24) & (count <= 32), (count > 32) & (count < 40), ] values = [ -1, -8, -8, -16, -8, ] return jnp.select(conditions, values, default=0) # check if player is on a new platform and cancel jump if so jump_cancel_up = ( is_jumping & (player_y >= new_landing_base_y) & (new_landing_base_y < jump_base_y) & (jump_counter > 32) ) jump_cancel_down = ( is_jumping & ((player_y + 1) == jump_base_y) # +1 because for some reason the player is 1 pixel below the value I would expect & (new_landing_base_y == (jump_base_y + 8)) & (jump_counter >= 40) ) jump_cancel = jump_cancel_up | jump_cancel_down jump_counter = jnp.where(jump_cancel, 40, jump_counter) jump_base_y = jnp.where(jump_cancel, new_landing_base_y, jump_base_y) new_y = jnp.where(jump_cancel, new_landing_base_y, player_y) new_cooldown_counter = jnp.where(jump_cancel, 8, state.player.cooldown_counter) total_offset = offset_for(jump_counter) new_y = jnp.where(is_jumping & ~jump_cancel, jump_base_y + total_offset, new_y) jump_complete = jump_counter >= 41 is_jumping = jnp.where(jump_complete, False, is_jumping) jump_counter = jnp.where(jump_complete, 0, jump_counter) return_value = ( new_y, jump_counter, is_jumping, jump_base_y, new_landing_base_y, jump_orientation, new_cooldown_counter, ) return jax.lax.cond( state.levelup_timer == 0, lambda: return_value, lambda: ( state.player.y, state.player.jump_counter, state.player.is_jumping, state.player.jump_base_y, state.player.landing_base_y, state.player.jump_orientation, state.player.cooldown_counter, ), ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _player_climb_controller(self, state: KangarooState, y: chex.Array, press_up: chex.Array, press_down: chex.Array, ladder_intersect: chex.Array) -> tuple[Array, Array, Array, Array, Array]: ladder_intersect_below = jnp.any(self._player_is_above_ladder(state)) new_y = y is_climbing = state.player.is_climbing is_climbing = jnp.where(state.player.is_jumping, False, is_climbing) climb_counter = state.player.climb_counter cooldown_over = state.player.cooldown_counter <= 0 climb_start = ( press_up & ~is_climbing & ladder_intersect & ~state.player.is_jumping & cooldown_over ) climb_start_downward = ( press_down & ~is_climbing & ladder_intersect_below & ~state.player.is_jumping & cooldown_over ) is_climbing = is_climbing | climb_start | climb_start_downward climb_counter = jnp.where(climb_start | climb_start_downward, 0, climb_counter) climb_base_y = jnp.where(climb_start, new_y, state.player.climb_base_y) climb_base_y = jnp.where( climb_start_downward, self._get_y_of_platform_below_player(state, 1) - self.consts.PLAYER_HEIGHT, climb_base_y, ) new_y = jnp.where(climb_start, new_y - 8, new_y) new_y = jnp.where(climb_start_downward, new_y + 8, new_y) climb_counter = jnp.where(is_climbing, climb_counter + 1, climb_counter) climb_up = jnp.logical_and(press_up, is_climbing) climb_down = jnp.logical_and(press_down, is_climbing) new_y = jnp.where( jnp.logical_and(climb_up, jnp.equal(climb_counter, 19)), new_y - 8, new_y ) new_y = jnp.where( jnp.logical_and(climb_down, jnp.equal(climb_counter, 19)), new_y + 8, new_y ) set_new_climb_base = ( climb_up & ((self._get_y_of_platform_below_player(state) - state.player.height) >= new_y) & ladder_intersect ) climb_base_y = jnp.where( set_new_climb_base, self._get_y_of_platform_below_player(state) - self.consts.PLAYER_HEIGHT, climb_base_y, ) climb_stop = is_climbing & (new_y >= climb_base_y) & ~climb_start_downward is_climbing = jnp.where(climb_stop, False, is_climbing) is_climbing = jnp.where(ladder_intersect | climb_start_downward, is_climbing, False) climb_counter = jnp.where(climb_counter >= 19, 0, climb_counter) cooldown_counter = jnp.where( climb_stop | set_new_climb_base, 15, jnp.where( state.player.cooldown_counter > 0, state.player.cooldown_counter - 1, 0 ), ) return_value = (new_y, is_climbing, climb_base_y, climb_counter, cooldown_counter) return jax.lax.cond( state.levelup_timer == 0, lambda: return_value, lambda: ( state.player.y, state.player.is_climbing, state.player.climb_base_y, state.player.climb_counter, state.player.cooldown_counter, ), ) @partial(jax.jit, static_argnums=(0,)) def _player_height_controller(self, is_jumping: chex.Array, jump_counter: chex.Array, is_crouching: chex.Array) -> chex.Array: def jump_based_height(count): conditions = [ (count < 8), (count < 16), (count < 24), (count < 40), ] values = [ 23, 24, 15, 23, ] return jnp.select(conditions, values, default=24) candidate_height = jump_based_height(jump_counter) height_if_jumping = jnp.where(is_jumping, candidate_height, 24) is_crouching = jnp.logical_and(is_crouching, jnp.logical_not(is_jumping)) new_height = jnp.where(is_crouching, 16, height_if_jumping) return new_height @partial(jax.jit, static_argnums=(0, 2), donate_argnums=(1,)) def _get_y_of_platform_below_player(self, state: KangarooState, y_offset=0) -> chex.Array: level_constants: LevelConstants = self._get_level_constants(state.current_level) platform_bands: jax.Array = self._get_platforms_below_player(state, y_offset) platform_ys = level_constants.platform_positions[:, 1] has_platform_below = jnp.any(platform_bands) platform_y = jnp.sum(platform_bands * platform_ys) return jnp.where(has_platform_below, platform_y, jnp.array(1000)) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _bell_step(self, state: KangarooState) -> Tuple[chex.Array, chex.Array]: """Handles bell collision detection and timer management. Returns: bell_timer: Updated bell timer value respawn_timer_done: Boolean indicating if respawn timer has completed """ bell_collision = self._entities_collide( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height, state.level.bell_position[0], state.level.bell_position[1], self.consts.BELL_WIDTH, self.consts.BELL_HEIGHT, ) bell_active = ~jnp.any(state.level.fruit_stages == 3) RESPAWN_AFTER_TICKS = 40 counter = state.level.bell_timer counter_start = bell_collision & (counter == 0) & bell_active counter = jnp.where(counter_start, 1, counter) counter = jnp.where(counter > 0, counter + 1, counter) counter = jnp.where(counter == RESPAWN_AFTER_TICKS + 1, 0, counter) respawn_timer_done = counter == RESPAWN_AFTER_TICKS return counter, respawn_timer_done @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _fruits_step(self, state: KangarooState) -> Tuple[chex.Array, chex.Array]: fruit_x = state.level.fruit_positions[:, 0] fruit_y = state.level.fruit_positions[:, 1] def check_fruit(p_x, p_y, p_w, p_h, f_x, f_y, f_w, f_h, stage, active): fruit_collision = self._entities_collide(p_x, p_y, p_w, p_h, f_x, f_y, f_w, f_h) collision_condition = jnp.logical_and(fruit_collision, active) return jnp.where(collision_condition, 100 * (2**stage), 0), jnp.where( collision_condition, False, active ) (score_additions, new_activations) = jax.vmap( check_fruit, in_axes=(None, None, None, None, 0, 0, None, None, 0, 0) )( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height, fruit_x, fruit_y, self.consts.FRUIT_WIDTH, self.consts.FRUIT_HEIGHT, state.level.fruit_stages, state.level.fruit_actives, ) new_score = jnp.sum(score_additions) bell_timer, respawn_timer_done = self._bell_step(state) def get_new_stages(respawn_timer_done, active, stage): return jnp.where( respawn_timer_done & (~active), jnp.clip(stage + 1, 0, 3), stage, ) new_stages = jax.vmap(get_new_stages, in_axes=(None, 0, 0))( respawn_timer_done, state.level.fruit_actives, state.level.fruit_stages ) activations = jax.lax.cond( respawn_timer_done, lambda: jnp.less_equal(new_stages, jnp.array([3, 3, 3])), lambda: new_activations, ) return new_score, activations, new_stages, bell_timer @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _child_step(self, state: KangarooState) -> Tuple[chex.Array]: RESET_TIMER_AFTER = 50 counter = state.level.child_timer counter = counter + 1 counter = jnp.where(counter > RESET_TIMER_AFTER, 0, counter) reset = counter == RESET_TIMER_AFTER child_velocity = state.level.child_velocity new_child_velocity = jnp.where(reset, child_velocity * -1, child_velocity) new_child_x = jnp.where( state.levelup_timer == 0, jnp.where( (counter % 5) == 0, state.level.child_position[0] + new_child_velocity, state.level.child_position[0], ), state.level.child_position[0], ) new_child_y = state.level.child_position[1] new_child_timer = counter return new_child_timer, new_child_x, new_child_y, new_child_velocity def _pad_array(self, arr: jax.Array, target_size: int): current_size = arr.shape[0] return jnp.pad( arr, ((0, target_size - current_size), (0, 0)), mode="constant", constant_values=-1, ) def _pad_to_size(self, level_constants: LevelConstants, max_platforms: int): return LevelConstants( ladder_positions=self._pad_array(level_constants.ladder_positions, max_platforms), ladder_sizes=self._pad_array(level_constants.ladder_sizes, max_platforms), platform_positions=self._pad_array(level_constants.platform_positions, max_platforms), platform_sizes=self._pad_array(level_constants.platform_sizes, max_platforms), fruit_positions=level_constants.fruit_positions, bell_position=level_constants.bell_position, child_position=level_constants.child_position, ) @partial(jax.jit, static_argnums=(0,)) def _get_level_constants(self, current_level: int) -> LevelConstants: max_platforms = 20 level1_padded = self._pad_to_size(self.consts.LEVEL_1, max_platforms) level2_padded = self._pad_to_size(self.consts.LEVEL_2, max_platforms) level3_padded = self._pad_to_size(self.consts.LEVEL_3, max_platforms) return jax.lax.cond( current_level == 1, lambda _: level1_padded, lambda _: jax.lax.cond( current_level == 2, lambda _: level2_padded, lambda _: level3_padded, operand=None, ), operand=None, ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _player_step(self, state: KangarooState, action: chex.Array): level_constants = self._get_level_constants(state.current_level) x, y = state.player.x, state.player.y old_height = state.player.height old_orientation = state.player.orientation # Get inputs press_right = jnp.any( jnp.array( [ action == Action.RIGHT, action == Action.UPRIGHT, action == Action.DOWNRIGHT, action == Action.RIGHTFIRE, action == Action.UPRIGHTFIRE, action == Action.DOWNRIGHTFIRE, ] ) ) press_left = jnp.any( jnp.array( [ action == Action.LEFT, action == Action.UPLEFT, action == Action.DOWNLEFT, action == Action.LEFTFIRE, action == Action.UPLEFTFIRE, action == Action.DOWNLEFTFIRE, ] ) ) press_up = jnp.any( jnp.array( [ action == Action.UP, action == Action.UPRIGHT, action == Action.UPLEFT, action == Action.UPFIRE, action == Action.UPRIGHTFIRE, action == Action.UPLEFTFIRE, ] ) ) # Store original fire press state before any modifications original_press_fire = jnp.any( jnp.array( [ action == Action.FIRE, action == Action.RIGHTFIRE, action == Action.LEFTFIRE, action == Action.UPFIRE, action == Action.DOWNFIRE, action == Action.UPLEFTFIRE, action == Action.UPRIGHTFIRE, action == Action.DOWNLEFTFIRE, action == Action.DOWNRIGHTFIRE, ] ) ) press_down = jnp.any( jnp.array( [ action == Action.DOWN, action == Action.DOWNLEFT, action == Action.DOWNRIGHT, action == Action.DOWNFIRE, action == Action.DOWNLEFTFIRE, action == Action.DOWNRIGHTFIRE, ] ) ) press_down = jnp.where(state.player.is_jumping, False, press_down) original_press_fire = jnp.where(state.player.is_jumping, False, original_press_fire) original_press_fire = jnp.where( state.player.is_climbing, False, original_press_fire ) press_up = jnp.where(press_down, False, press_up) press_right = jnp.where(state.player.is_climbing, False, press_right) press_left = jnp.where(state.player.is_climbing, False, press_left) is_looking_left = state.player.orientation == -1 is_looking_right = state.player.orientation == 1 # Update punch counter new_punch_counter = jnp.where( original_press_fire, state.player.punch_counter + 1, state.player.punch_counter ) # Reset counter when fire is released new_punch_counter = jnp.where( ~original_press_fire & (state.player.punch_counter > 0), 0, new_punch_counter ) # Set needs_release flag when counter reaches 28 and keep it true until spacebar is released new_needs_release = jnp.where( new_punch_counter >= 28, True, # Need to release spacebar jnp.where( ~original_press_fire, # If spacebar is released False, # Reset the flag state.player.needs_release, # Otherwise keep current state ), ) # Only allow punching if either: # 1. Counter is below 28, or # 2. Spacebar has been released after hitting 28 can_punch = jnp.logical_or(new_punch_counter < 28, ~new_needs_release) press_fire = jnp.where(can_punch, original_press_fire, False) is_punching_left = ( jnp.logical_and(press_fire, is_looking_left) & ~state.player.is_crashing ) is_punching_right = ( jnp.logical_and(press_fire, is_looking_right) & ~state.player.is_crashing ) ladder_intersect_thresh = jnp.any(self._check_ladder_collisions(state)) ladder_intersect_no_thresh = jnp.any(self._check_ladder_collisions(state, 0)) ladder_intersect = jnp.where( state.player.is_climbing, ladder_intersect_no_thresh, ladder_intersect_thresh ) ( new_y, new_jump_counter, new_is_jumping, new_jump_base_y, new_landing_base_y, new_jump_orientation, new_cooldown_counter, ) = self._player_jump_controller(state, press_up, ladder_intersect) ( new_y, new_is_climbing, new_climb_base_y, new_climb_counter, new_cooldown_counter, ) = self._player_climb_controller(state, new_y, press_up, press_down, ladder_intersect) new_is_crouching = press_down & ~new_is_climbing & ~new_is_jumping candidate_vel_x = jnp.where( press_left, -self.consts.MOVEMENT_SPEED, jnp.where(press_right, self.consts.MOVEMENT_SPEED, 0) ) standing_still = jnp.equal(candidate_vel_x, 0) new_orientation = jnp.sign(candidate_vel_x) new_orientation = jnp.where(standing_still, old_orientation, new_orientation) stop_in_air = jnp.logical_and( new_is_jumping, state.player.jump_orientation != new_orientation ) vel_x = jnp.where(stop_in_air, 0, candidate_vel_x) # Detect if a crouch-jump was just initiated. did_crouch_jump = state.player.is_crouching & press_up & ~state.player.is_jumping new_player_height = self._player_height_controller( is_jumping=new_is_jumping, jump_counter=new_jump_counter, is_crouching=new_is_crouching, ) new_player_height = jnp.where( (state.levelup_timer > 0) | state.player.is_crashing, self.consts.PLAYER_HEIGHT, new_player_height, ) # If we just performed a crouch-jump, the jump controller already handled # the y-position adjustment. We set dy to 0 to prevent a double correction. effective_old_height = jnp.where(did_crouch_jump, new_player_height, old_height) dy = effective_old_height - new_player_height new_y = new_y + dy # x-axis movement x = jnp.where( state.level.step_counter % 3 != 0, x, jnp.where( state.player.is_crashing | state.levelup_timer != 0, x, jnp.clip(x + vel_x, self.consts.LEFT_CLIP, self.consts.RIGHT_CLIP - self.consts.PLAYER_WIDTH), ), ) platform_bools: jax.Array = self._get_platforms_below_player(state) platform_ys: jax.Array = level_constants.platform_positions[:, 1] valid_platforms = self._get_valid_platforms(level_constants) valid_and_affecting = jnp.logical_and(platform_bools, valid_platforms) climbing_transition = ~state.player.is_climbing & new_is_climbing & press_down # For each platform, calculate what y would be if player is positioned on it platform_y_values = jnp.where( climbing_transition | state.player.is_jumping, new_y, jnp.clip(new_y, 0, platform_ys - new_player_height) ) masked_platform_y_values = jnp.where(valid_and_affecting, platform_y_values, new_y) platform_dependent_y = jnp.min( jnp.where(valid_and_affecting, masked_platform_y_values, self.consts.SCREEN_HEIGHT) ) y = jnp.where( state.player.is_crashing, jnp.where((y + new_player_height) > self.consts.SCREEN_HEIGHT, y, y + 2), platform_dependent_y, ) final_platform_y = 28 player_on_last_platform = (new_y + new_player_height) == final_platform_y level_finished = ( player_on_last_platform & ~state.level_finished & (state.levelup_timer == 0) ) y = jnp.where(state.levelup_timer == 0, y, state.player.y) x = jnp.where(state.reset_coords, self.consts.PLAYER_START_X, x) y = jnp.where(state.reset_coords, self.consts.PLAYER_START_Y, y) return ( x, y, vel_x, new_is_crouching, new_is_jumping, new_is_climbing, new_jump_counter, new_orientation, new_jump_base_y, new_landing_base_y, new_player_height, new_jump_orientation, new_climb_base_y, new_climb_counter, is_punching_left, is_punching_right, new_cooldown_counter, level_finished, new_punch_counter, new_needs_release, ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _timer_controller(self, state: KangarooState): return jnp.where( state.level.step_counter == 255, state.level.timer - 100, state.level.timer ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _next_level(self, state: KangarooState): RESET_AFTER_TICKS = 256 counter = state.levelup_timer counter_start = state.level_finished & (counter == 0) counter = jnp.where((counter > 0) | counter_start, counter + 1, counter) reset_timer_done = counter == RESET_AFTER_TICKS counter = jnp.where(counter > RESET_AFTER_TICKS, 0, counter) reset_coords = jnp.where(reset_timer_done, jnp.array(True), jnp.array(False)) levelup = jnp.where(reset_timer_done, jnp.array(True), jnp.array(False)) current_level = jnp.where(levelup, state.current_level + 1, state.current_level) return current_level, counter, reset_coords, levelup @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _lives_controller(self, state: KangarooState): is_time_over = state.level.timer <= 0 new_last_stood_on_platform_y = jnp.where( self._get_y_of_platform_below_player(state) == (state.player.y + state.player.height), self._get_y_of_platform_below_player(state), state.player.last_stood_on_platform_y, ) # platform_drop_check() y_of_platform_below_player = self._get_y_of_platform_below_player(state) player_is_falling = ( (state.player.y + state.player.height) == state.player.last_stood_on_platform_y ) & (y_of_platform_below_player > state.player.last_stood_on_platform_y) & (~state.player.is_jumping) # monkey touch check def check_monkey_collision(p_x, p_y, p_w, p_h, m_x, m_y, m_w, m_h, m_state): # Add a small delay before re-enabling collision detection # Only check collision if monkey state is not 0 and not in the process of being punched return jnp.logical_and( self._entities_collide(p_x, p_y, p_w, p_h, m_x, m_y, m_w, m_h), jnp.logical_and( m_state != 0, jnp.logical_not( jnp.logical_and( m_state == 0, jnp.logical_and( m_x == 152, m_y == 5 # If monkey is at spawn position ), ) ), ), ) monkey_collision = jax.vmap( check_monkey_collision, in_axes=(None, None, None, None, 0, 0, None, None, 0), )( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height, state.level.monkey_positions[:, 0], state.level.monkey_positions[:, 1], self.consts.MONKEY_WIDTH, self.consts.MONKEY_HEIGHT - 1, state.level.monkey_states, ) player_collided_with_monkey = jnp.any(monkey_collision) def check_collision(p_x, p_y, p_w, p_h, m_x, m_y, m_w, m_h, m_state): return jnp.logical_and( self._entities_collide(p_x, p_y, p_w, p_h, m_x, m_y, m_w, m_h - 1), m_state != 0 ) collision = jax.vmap( check_collision, in_axes=(None, None, None, None, 0, 0, None, None, 0), )( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height, state.level.coco_positions[:, 0], state.level.coco_positions[:, 1], self.consts.THROWN_COCONUT_WIDTH, self.consts.THROWN_COCONUT_HEIGHT, state.level.coco_states, ) player_collided_with_horizontal_coco = jnp.any(collision) crashed_falling_coco = self._entities_collide_with_threshold( state.player.x, state.player.y, self.consts.PLAYER_WIDTH, state.player.height - 8, state.level.falling_coco_position[0], state.level.falling_coco_position[1], self.consts.FALLING_COCONUT_WIDTH, self.consts.FALLING_COCONUT_HEIGHT, 0.1, ) remove_live = ( is_time_over | player_is_falling | crashed_falling_coco | player_collided_with_monkey | player_collided_with_horizontal_coco ) & ~state.player.is_crashing new_is_crashing = jnp.where(remove_live, True, state.player.is_crashing) start_timer = ( state.player.is_crashing & (state.player.chrash_timer == 0) & ((state.player.y + state.player.height) > self.consts.SCREEN_HEIGHT) ) RESPAWN_AFTER_TICKS = 40 counter = state.player.chrash_timer counter_start = start_timer counter = jnp.where(counter_start, 1, counter) counter = jnp.where(counter > 0, counter + 1, counter) counter = jnp.where(counter == RESPAWN_AFTER_TICKS + 1, 0, counter) crash_timer_done = counter == RESPAWN_AFTER_TICKS new_is_crashing = jnp.where(crash_timer_done, False, new_is_crashing) return ( jnp.where(remove_live, state.lives - 1, state.lives), new_is_crashing, counter, crash_timer_done, new_last_stood_on_platform_y, ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _falling_coconut_controller(self, state: KangarooState, punching: chex.Array): falling_coco_exists = (state.level.falling_coco_position[0] != 13) | ( state.level.falling_coco_position[1] != -1 ) spawn_new_coco = ~falling_coco_exists & (state.level.step_counter == 255) update_positions = ~state.level.falling_coco_skip_update & ( ((state.level.step_counter % 8) == 0) | spawn_new_coco ) # coco go down or up before dropping coco_down = (state.level.step_counter % 32) < 16 # detect if coco is above player and switch from x-following state to dropping state coco_first_time_above_player = ( ~state.level.falling_coco_dropping & falling_coco_exists & ( ((state.level.falling_coco_position[0] + self.consts.FALLING_COCONUT_WIDTH) > state.player.x) & (state.level.falling_coco_position[0] < (state.player.x + self.consts.PLAYER_WIDTH)) ) & update_positions ) update_positions = jnp.where(coco_first_time_above_player, True, update_positions) new_falling_coco_dropping = jnp.where( coco_first_time_above_player, True, state.level.falling_coco_dropping ) new_falling_coco_skip_update = coco_first_time_above_player new_falling_coco_skip_update = jnp.where( state.level.falling_coco_skip_update & (((state.level.step_counter % 8) == 0) | spawn_new_coco), False, new_falling_coco_skip_update | state.level.falling_coco_skip_update, ) new_falling_coco_counter = jnp.where( update_positions, jnp.where( spawn_new_coco, 0, jnp.where( state.level.falling_coco_dropping & update_positions, # coco is dropping state.level.falling_coco_counter + 1, jnp.where( update_positions & coco_down, # coco is going down state.level.falling_coco_counter + 1, state.level.falling_coco_counter - 1, ), ), ), state.level.falling_coco_counter, ) # detect if player is punching the coco fist_w = 3 fist_h = 4 fist_x = jnp.where( state.player.orientation > 0, state.player.x + self.consts.PLAYER_WIDTH, state.player.x - fist_w, ) fist_y = state.player.y + 8 coco_punching = ( self._entities_collide_with_threshold( fist_x, fist_y, fist_w, fist_h, state.level.falling_coco_position[0], state.level.falling_coco_position[1], self.consts.FALLING_COCONUT_WIDTH, self.consts.FALLING_COCONUT_HEIGHT, 0.01, ) & punching ) score_addition = jnp.where(coco_punching, 200, 0) reset_coco = ( (new_falling_coco_counter > 20) | state.player.is_crashing | coco_punching ) new_falling_coco_counter = jnp.where(reset_coco, 0, new_falling_coco_counter) new_falling_coco_dropping = jnp.where(reset_coco, False, new_falling_coco_dropping) new_falling_coco_position_x = jnp.where( update_positions & ~state.level.falling_coco_dropping & (falling_coco_exists | spawn_new_coco), state.level.falling_coco_position[0] + 2, state.level.falling_coco_position[0], ) new_falling_coco_position_y = jnp.where( update_positions & (falling_coco_exists | spawn_new_coco), 8 * new_falling_coco_counter + 9, state.level.falling_coco_position[1], ) new_falling_coco_position = jnp.where( reset_coco, jnp.array([13, -1]), jnp.array([new_falling_coco_position_x, new_falling_coco_position_y]), ) return ( new_falling_coco_position, new_falling_coco_dropping, new_falling_coco_counter, new_falling_coco_skip_update, score_addition, ) @partial(jax.jit, static_argnums=(0,)) def _update_coco_state( self, old_m_state: chex.Array, new_m_state: chex.Array, old_m_timer: chex.Array, new_m_timer: chex.Array, c_state: chex.Array, c_pos_x: chex.Array, ) -> chex.Array: return jnp.where( (old_m_state != 3) & (new_m_state == 3), 1, jnp.where( (c_state == 1) & (old_m_timer == 3) & (new_m_timer == 2), 2, jnp.where(c_pos_x <= 15, 0, c_state), ), ) @partial(jax.jit, static_argnums=(0,)) def _update_coco_positions( self, new_c_state: chex.Array, old_c_state: chex.Array, stepc: chex.Array, old_c_pos: chex.Array, new_m_pos: chex.Array, spawn_position: chex.Array, ) -> chex.Array: return jnp.where( new_c_state == 2, jnp.where( stepc % 2 == 0, jnp.array([old_c_pos[0] - 2, old_c_pos[1]]), old_c_pos, ), jnp.where( (new_c_state == 1) & (old_c_state == 0), jnp.array( [ new_m_pos[0] - 6, jnp.where( spawn_position, new_m_pos[1] - 5, new_m_pos[1] + self.consts.MONKEY_HEIGHT - self.consts.THROWN_COCONUT_HEIGHT, ), ] ), old_c_pos, ), ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _monkey_controller2(self, state: KangarooState, punching: chex.Array): return ( state.level.monkey_states, # new_monkey_states (all zeros) state.level.monkey_positions, # new_monkey_positions (all spawn coords/off-screen) state.level.monkey_throw_timers, # new_monkey_throw_timers (all zeros) jnp.zeros((), dtype=jnp.int32), # score_addition (0) state.level.coco_positions, # new_coco_positions (all off-screen) state.level.coco_states, # new_coco_states (all zeros) jnp.array(False), # flip (should be False) ) @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def _monkey_controller(self, state: KangarooState, punching: chex.Array): current_monkeys_existing = jnp.sum(state.level.monkey_states != 0) spawn_new_monkey = ( ~state.level.spawn_protection & (current_monkeys_existing < 4) & (state.level.step_counter == 16) ) monkey_states_is_zero = state.level.monkey_states == 0 first_non_existing_monkey_index = jnp.argmin(~monkey_states_is_zero) first_non_existing_monkey_index = jnp.where( jnp.any(monkey_states_is_zero), first_non_existing_monkey_index, jnp.array(-1) ) new_monkey_states = state.level.monkey_states new_monkey_states = jax.lax.cond( spawn_new_monkey, lambda: new_monkey_states.at[first_non_existing_monkey_index].set(1), lambda: new_monkey_states, ) monkey_lower_y = state.level.monkey_positions[:, 1] + self.consts.MONKEY_HEIGHT monkey_on_p1 = monkey_lower_y == 172 monkey_on_p2 = monkey_lower_y == 124 monkey_on_p3 = monkey_lower_y == 76 platform_y_under_player = self._get_y_of_platform_below_player(state) transition_1_to_2 = ( ( monkey_on_p1 & (platform_y_under_player <= 172) & (platform_y_under_player > 124) ) | ( monkey_on_p2 & (platform_y_under_player <= 124) & (platform_y_under_player > 76) ) | ( monkey_on_p3 & (platform_y_under_player <= 76) & (platform_y_under_player > 28) ) ) new_monkey_states = jnp.where( (new_monkey_states == 1) & transition_1_to_2, 2, new_monkey_states ) in_state_1 = new_monkey_states == 1 should_transition = ( state.level.monkey_positions[:, 1] + self.consts.MONKEY_HEIGHT ) >= 172 new_monkey_states = jnp.where( in_state_1 & should_transition, 5, new_monkey_states, ) in_state_2 = new_monkey_states == 2 monkey_x_positions = state.level.monkey_positions[:, 0] min_x_reached = monkey_x_positions <= 107 new_monkey_states = jnp.where( in_state_2 & min_x_reached, 3, new_monkey_states, ) in_state_3 = new_monkey_states == 3 timer_is_zero = state.level.monkey_throw_timers == 0 should_transition = in_state_3 & timer_is_zero & (state.level.monkey_states == 3) new_monkey_states = jnp.where(should_transition, 4, new_monkey_states) in_state_4 = new_monkey_states == 4 monkey_x_positions = state.level.monkey_positions[:, 0] reached_right_position = monkey_x_positions >= 146 new_monkey_states = jnp.where( in_state_4 & reached_right_position, 5, new_monkey_states, ) in_state_5 = new_monkey_states == 5 monkey_y_positions = state.level.monkey_positions[:, 1] reached_top_position = monkey_y_positions <= 5 new_monkey_states = jnp.where( in_state_5 & reached_top_position, 0, new_monkey_states, ) def update_single_monkey_position( state_monkey, position_monkey, new_state_monkey, step_counter ): should_update = step_counter % 16 == 0 pos_state_0 = jnp.array([152, 5]) pos_state_1 = jnp.where( state_monkey == 0, jnp.array([152, 5]), jnp.array([position_monkey[0], position_monkey[1] + 8]), ) pos_state_2 = jnp.array([position_monkey[0] - 3, position_monkey[1]]) pos_state_3 = position_monkey pos_state_4 = jnp.array([position_monkey[0] + 3, position_monkey[1]]) pos_state_5 = jnp.where( state_monkey == 1, jnp.array([146, position_monkey[1]]), jnp.array([position_monkey[0], position_monkey[1] - 16]), ) def new_state(state_monkey): return jnp.array( [ (state_monkey == 0), (state_monkey == 1), (state_monkey == 2), (state_monkey == 3), (state_monkey == 4), (state_monkey == 5), ] ) new_pos = jnp.select( new_state(new_state_monkey), [ pos_state_0, pos_state_1, pos_state_2, pos_state_3, pos_state_4, pos_state_5, ], default=position_monkey, ) return jnp.where(should_update, new_pos, position_monkey) new_monkey_positions = jax.vmap( update_single_monkey_position, in_axes=(0, 0, 0, None) )( state.level.monkey_states, state.level.monkey_positions, new_monkey_states, state.level.step_counter, ) def update_timer(new_state, old_state, current_timer, step_counter): return jnp.where( new_state == 3, jnp.where( old_state == 2, 4, jnp.where(step_counter % 16 == 0, current_timer - 1, current_timer), ), current_timer, ) new_monkey_throw_timers = jax.vmap(update_timer, in_axes=(0, 0, 0, None))( new_monkey_states, state.level.monkey_states, state.level.monkey_throw_timers, state.level.step_counter, ) # Call the extracted _update_coco_state method new_coco_states = jax.vmap(self._update_coco_state, in_axes=(0, 0, 0, 0, 0, 0))( state.level.monkey_states, new_monkey_states, state.level.monkey_throw_timers, new_monkey_throw_timers, state.level.coco_states, state.level.coco_positions[:, 0], ) # Call the extracted _update_coco_positions method new_coco_positions = jax.vmap( self._update_coco_positions, in_axes=(0, 0, None, 0, 0, None) )( new_coco_states, state.level.coco_states, state.level.step_counter, state.level.coco_positions, new_monkey_positions, state.level.spawn_position, ) # Handle punching at the very end fist_w = 3 fist_h = 4 fist_x = jnp.where( state.player.orientation > 0, state.player.x + self.consts.PLAYER_WIDTH, state.player.x - fist_w, ) fist_y = state.player.y + 8 def check_punch(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h, m_state, punching): return jnp.logical_and( self._entities_collide(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h), jnp.logical_and(m_state != 0, punching), ) monkeys_punched = jax.vmap( check_punch, in_axes=(None, None, None, None, 0, 0, None, None, 0, None), )( fist_x, fist_y, fist_w, fist_h, state.level.monkey_positions[:, 0], state.level.monkey_positions[:, 1], self.consts.MONKEY_WIDTH, self.consts.MONKEY_HEIGHT, state.level.monkey_states, punching, ) score_addition = jnp.sum(monkeys_punched) * 200 new_monkey_states = jax.vmap(lambda a, b: jnp.where(b, 0, a), in_axes=(0, 0))( new_monkey_states, monkeys_punched ) new_monkey_positions = jax.vmap( lambda pos, punched: jnp.where(punched, jnp.array([152, 5]), pos), in_axes=(0, 0), )(new_monkey_positions, monkeys_punched) flip = jnp.any((state.level.monkey_states != 3) & (new_monkey_states == 3)) return ( new_monkey_states, new_monkey_positions, new_monkey_throw_timers, score_addition, new_coco_positions, new_coco_states, flip, )
[docs] @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: KangarooObservation) -> chex.Array: """Converts the observation to a flat array.""" return jnp.concatenate( [ obs.player_x.flatten(), obs.player_y.flatten(), obs.player_o.flatten(), obs.platform_positions.flatten(), obs.ladder_positions.flatten(), obs.fruit_positions.flatten(), obs.bell_position.flatten(), obs.child_position.flatten(), obs.falling_coco_position.flatten(), obs.coco_positions.flatten(), obs.monkey_positions.flatten(), ] )
[docs] def render(self, state: KangarooState) -> jnp.ndarray: return self.renderer.render(state)
[docs] def action_space(self) -> spaces.Discrete: return spaces.Discrete(len(self.ACTION_SET))
[docs] def observation_space(self) -> spaces.Dict: """Returns the observation space for Kangaroo. The observation contains: - player_x: int (0-160) - player_y: int (0-210) - player_o: int (-1 or 1 for orientation) - platform_positions: array of shape (20, 2) with x,y coordinates (0-160, 0-210) - ladder_positions: array of shape (20, 2) with x,y coordinates (0-160, 0-210) - fruit_positions: array of shape (3, 2) with x,y coordinates (0-160, 0-210) - bell_position: array of shape (2,) with x,y coordinates (0-160, 0-210) - child_position: array of shape (2,) with x,y coordinates (0-160, 0-210) - falling_coco_position: array of shape (2,) with x,y coordinates (0-160, 0-210) - monkey_positions: array of shape (4, 2) with x,y coordinates (0-160, 0-210) - coco_positions: array of shape (4, 2) with x,y coordinates (0-160, 0-210) """ return spaces.Dict( { "player_x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "player_y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), "player_o": spaces.Box(low=-1, high=1, shape=(), dtype=jnp.int32), "platform_positions": spaces.Box( low=-1, high=210, shape=(20, 2), dtype=jnp.int32 ), "ladder_positions": spaces.Box( low=-1, high=210, shape=(20, 2), dtype=jnp.int32 ), "fruit_positions": spaces.Box( low=-1, high=160, shape=(3, 2), dtype=jnp.int32 ), "bell_position": spaces.Box( low=-1, high=160, shape=(2,), dtype=jnp.int32 ), "child_position": spaces.Box( low=0, high=160, shape=(2,), dtype=jnp.int32 ), "falling_coco_position": spaces.Box( low=-1, high=160, shape=(2,), dtype=jnp.int32 ), "monkey_positions": spaces.Box( low=-1, high=160, shape=(4, 2), dtype=jnp.int32 ), "coco_positions": spaces.Box( low=-1, high=160, shape=(4, 2), dtype=jnp.int32 ), } )
[docs] def image_space(self) -> spaces.Box: return spaces.Box( low=0, high=255, shape=(210, 160, 3), dtype=jnp.uint8 )
[docs] @partial(jax.jit, static_argnums=(0,)) def reset(self, key=None) -> Tuple[ KangarooObservation, KangarooState, ]: state = self.reset_level(1) obs = self._get_observation(state) return obs, state
[docs] @partial(jax.jit, static_argnums=(0,)) def reset_level(self, next_level=1) -> KangarooState: next_level = jnp.clip(next_level, 1, 3) level_constants: LevelConstants = self._get_level_constants(next_level) new_state = KangarooState( player=PlayerState( x=jnp.array(self.consts.PLAYER_START_X), y=jnp.array(self.consts.PLAYER_START_Y), vel_x=jnp.array(0), is_crouching=jnp.array(False), is_jumping=jnp.array(False), is_climbing=jnp.array(False), jump_counter=jnp.array(0), orientation=jnp.array(1), jump_base_y=jnp.array(self.consts.PLAYER_START_Y), landing_base_y=jnp.array(self.consts.PLAYER_START_Y), height=jnp.array(self.consts.PLAYER_HEIGHT), jump_orientation=jnp.array(0), climb_base_y=jnp.array(self.consts.PLAYER_START_Y), climb_counter=jnp.array(0), punch_left=jnp.array(False), punch_right=jnp.array(False), cooldown_counter=jnp.array(0), chrash_timer=jnp.array(0), is_crashing=jnp.array(False), last_stood_on_platform_y=jnp.array(1000), walk_animation=jnp.array(0), punch_counter=jnp.array(0), needs_release=jnp.array(False), ), level=LevelState( bell_position=level_constants.bell_position, bell_timer=jnp.array(0), fruit_positions=level_constants.fruit_positions, fruit_actives=jnp.ones(3, dtype=jnp.bool_), fruit_stages=jnp.zeros(3, dtype=jnp.int32), ladder_positions=level_constants.ladder_positions, ladder_sizes=level_constants.ladder_sizes, platform_positions=level_constants.platform_positions, platform_sizes=level_constants.platform_sizes, child_position=level_constants.child_position, child_timer=jnp.array(0), child_velocity=jnp.array(1), timer=jnp.array(2000), # to be modified falling_coco_position=jnp.array([13, -1]), falling_coco_dropping=jnp.array(False), falling_coco_counter=jnp.array(0), falling_coco_skip_update=jnp.array(False), step_counter=jnp.array(0), monkey_states=jnp.zeros(4, dtype=jnp.int32), monkey_positions=jnp.array([[152, 5], [152, 5], [152, 5], [152, 5]]), monkey_throw_timers=jnp.zeros(4, dtype=jnp.int32), spawn_protection=jnp.array(True), coco_positions=jnp.array( [[-10, -10], [-10, -10], [-10, -10], [-10, -10]] ), coco_states=jnp.zeros(4, dtype=jnp.int32), spawn_position=jnp.array(False), bell_animation=jnp.array(0), ), score=jnp.array(0), current_level=next_level, level_finished=jnp.array(False), levelup_timer=jnp.array(0), reset_coords=jnp.array(False), levelup=jnp.array(False), lives=jnp.array(3), ) return new_state
[docs] @partial(jax.jit, static_argnums=(0,), donate_argnums=(1,)) def step(self, state: KangarooState, action: chex.Array) -> Tuple[KangarooObservation, KangarooState, float, bool, KangarooInfo]: # Translate compact agent action index to ALE console action action = jnp.take(self.ACTION_SET, action.astype(jnp.int32)) reset_cond = jnp.any(jnp.array([action == self.consts.RESET])) ( player_x, player_y, vel_x, is_crouching, is_jumping, is_climbing, jump_counter, orientation, jump_base_y, landing_base_y, new_player_height, new_jump_orientation, climb_base_y, climb_counter, punch_left, punch_right, cooldown_counter, level_finished, punch_counter, needs_release, ) = self._player_step(state, action) new_current_level, new_levelup_timer, new_reset_coords, new_levelup = ( self._next_level(state) ) # Handle fruit collection fruit_score_addition, new_actives, new_fruit_stages, bell_timer = self._fruits_step( state ) child_timer, new_child_x, new_child_y, new_child_velocity = self._child_step(state) new_main_timer = self._timer_controller(state) ( new_falling_coco_position, new_falling_coco_dropping, new_falling_coco_counter, new_falling_coco_skip_update, falling_coco_score_addition, ) = self._falling_coconut_controller(state, punch_left | punch_right) ( new_monkey_states, new_monkey_positions, new_monkey_throw_timers, monkey_hit_score_addition, new_coco_positions, new_coco_states, flip, ) = self._monkey_controller(state, (punch_left | punch_right)) ( new_lives, new_is_crashing, crash_timer, crash_timer_done, new_last_stood_on_platform_y, ) = self._lives_controller(state) # add the time after finishing a level level_switch_score_addition = jnp.where(level_finished, state.level.timer, 0) # add score if levelup from lvl3 to lvl1 score_addition = ( fruit_score_addition + monkey_hit_score_addition + level_switch_score_addition + falling_coco_score_addition ) score_addition = jax.lax.cond( new_current_level == 4, lambda: score_addition + 1400, lambda: score_addition, ) new_current_level = jnp.where(new_current_level == 4, 1, new_current_level) new_bell_animation_timer = jnp.where( bell_timer > 0, jnp.where(state.level.bell_animation == 0, 192, state.level.bell_animation), jnp.where( state.level.bell_animation > 0, state.level.bell_animation - 1, state.level.bell_animation, ), ) new_level_state = jax.lax.cond( new_levelup, lambda: self.reset_level(new_current_level).level, lambda: jax.lax.cond( crash_timer_done, lambda: self.reset_level(state.current_level).level, lambda: LevelState( bell_position=state.level.bell_position, fruit_positions=state.level.fruit_positions, ladder_positions=state.level.ladder_positions, ladder_sizes=state.level.ladder_sizes, platform_positions=state.level.platform_positions, platform_sizes=state.level.platform_sizes, child_position=jnp.array([new_child_x, new_child_y]), timer=new_main_timer, bell_timer=bell_timer, child_timer=child_timer, child_velocity=new_child_velocity, fruit_actives=new_actives, fruit_stages=new_fruit_stages, falling_coco_position=jnp.where( state.levelup_timer == 0, new_falling_coco_position, state.level.falling_coco_position, ), falling_coco_dropping=new_falling_coco_dropping, falling_coco_counter=new_falling_coco_counter, falling_coco_skip_update=new_falling_coco_skip_update, step_counter=(state.level.step_counter + 1) % 256, monkey_positions=jnp.where( state.levelup_timer == 0, new_monkey_positions, state.level.monkey_positions, ), monkey_states=new_monkey_states, monkey_throw_timers=new_monkey_throw_timers, spawn_protection=jnp.where( (state.level.step_counter == 255) & state.level.spawn_protection, False, state.level.spawn_protection, ), coco_positions=new_coco_positions, coco_states=new_coco_states, spawn_position=jnp.where( flip, ~state.level.spawn_position, state.level.spawn_position, ), bell_animation=new_bell_animation_timer, ), ), ) currently_walking = jnp.logical_or( jnp.logical_or( jnp.logical_or(action == Action.RIGHT, action == Action.LEFT), jnp.logical_or(action == Action.UPRIGHT, action == Action.UPLEFT), ), jnp.logical_or(action == Action.DOWNRIGHT, action == Action.DOWNLEFT), ) new_walk_counter = jnp.where( currently_walking, state.player.walk_animation + 1, 0 ) new_walk_counter = jnp.where(new_walk_counter == 16, 0, new_walk_counter) new_player_state = jax.lax.cond( crash_timer_done, lambda: self.reset_level(state.current_level).player, lambda: PlayerState( x=player_x, y=player_y, vel_x=vel_x, is_crouching=is_crouching, is_jumping=is_jumping, is_climbing=is_climbing, jump_counter=jump_counter, orientation=orientation, jump_base_y=jump_base_y, landing_base_y=landing_base_y, height=new_player_height, jump_orientation=new_jump_orientation, climb_base_y=climb_base_y, climb_counter=climb_counter, punch_left=punch_left, punch_right=punch_right, cooldown_counter=cooldown_counter, chrash_timer=crash_timer, is_crashing=new_is_crashing, last_stood_on_platform_y=new_last_stood_on_platform_y, walk_animation=new_walk_counter, punch_counter=punch_counter, needs_release=needs_release, ), ) new_state = jax.lax.cond( reset_cond, lambda: self.reset_level(1), lambda: KangarooState( player=new_player_state, level=new_level_state, score=state.score + score_addition, current_level=new_current_level, level_finished=level_finished, levelup_timer=new_levelup_timer, reset_coords=new_reset_coords, levelup=new_levelup, lives=new_lives, ), ) done = self._get_done(new_state) env_reward = self._get_reward(state, new_state) info = self._get_info(new_state) observation = self._get_observation(new_state) return observation, new_state, env_reward, done, info
@partial(jax.jit, static_argnums=(0,)) def _get_observation(self, state: KangarooState) -> KangarooObservation: fruit_mask = state.level.fruit_actives[:, jnp.newaxis] fruit_positions = jnp.where( fruit_mask, state.level.fruit_positions, jnp.array([-1, -1]) ) bell_mask = jnp.any(state.level.bell_position != jnp.array([-1, -1])) bell_position = jnp.where( bell_mask, state.level.bell_position, jnp.array([-1, -1]) ) falling_coco_mask = state.level.falling_coco_dropping[None] falling_coco_position = jnp.where( falling_coco_mask, state.level.falling_coco_position, jnp.array([-1, -1]) ) monkey_mask = state.level.monkey_states[:, jnp.newaxis] monkey_positions = jnp.where( monkey_mask, state.level.monkey_positions, jnp.array([-1, -1]) ) coco_mask = state.level.coco_states[:, jnp.newaxis] coco_positions = jnp.where( coco_mask, state.level.coco_positions, jnp.array([-1, -1]) ) return KangarooObservation( player_x=state.player.x, player_y=state.player.y, player_o=state.player.orientation, platform_positions=state.level.platform_positions, ladder_positions=state.level.ladder_positions, fruit_positions=fruit_positions, bell_position=bell_position, child_position=state.level.child_position, falling_coco_position=falling_coco_position, monkey_positions=monkey_positions, coco_positions=coco_positions, ) @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: KangarooState) -> KangarooInfo: return KangarooInfo( score=state.score, level=state.current_level, ) @partial(jax.jit, static_argnums=(0,)) def _get_reward( self, previous_state: KangarooState, state: KangarooState ) -> float: return state.score - previous_state.score @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: KangarooState) -> bool: return jnp.logical_and(state.lives <= 0, state.player.y == 188)
[docs] class KangarooRenderer(JAXGameRenderer): def __init__(self, consts: KangarooConstants = None): """ Initializes the renderer by loading sprites, including level backgrounds. Args: sprite_path: Path to the directory containing sprite .npy files. """ self.consts = consts or KangarooConstants() self.rendering_config = render_utils.RendererConfig( game_dimensions=(210, 160), channels=3, ) self.jr = render_utils.JaxRenderingUtils(self.rendering_config) # Load and process all sprites ( self.PALETTE, self.SHAPE_MASKS, self.BACKGROUND, self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self._load_sprites() # get color ideas from the background for ladder and platform rendering self.LADDER_COLOR_ID = self.COLOR_TO_ID.get((162, 98, 33), 0) self.PLATFORM_COLOR_ID = self.COLOR_TO_ID.get((162, 98, 33), 0) # Pre-calculate static ladder properties (these should be constant even for different sized ladders) -> meaning that ladder heights should be divisible by 4!! self.ladder_rung_height = 4 self.ladder_space_height = 4 def _load_sprites(self): """Defines the asset manifest for Kangaroo and loads them via the utility function.""" sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/kangaroo" # 2. Make one call to the utility function. Done. return self.jr.load_and_setup_assets(self.consts.ASSET_CONFIG, sprite_path) @partial(jax.jit, static_argnums=(0,)) def _render_hook_post_ui(self, raster: jnp.ndarray, state: KangarooState): return raster @partial(jax.jit, static_argnums=(0,)) def _draw_bell(self, raster: jnp.ndarray, state: KangarooState): bell_anim_on = ((state.level.bell_animation >= 176) & (state.level.bell_animation <= 192)) | \ ((state.level.bell_animation >= 128) & (state.level.bell_animation <= 143)) | \ ((state.level.bell_animation >= 80) & (state.level.bell_animation <= 95)) | \ ((state.level.bell_animation >= 32) & (state.level.bell_animation <= 47)) bell_idx = jax.lax.select(bell_anim_on, 1, 0) # 1 for ringing, 0 for still bell_mask = self.SHAPE_MASKS["bell"][bell_idx] flip_bell = ((state.level.bell_animation >= 176) & (state.level.bell_animation <= 192)) | \ ((state.level.bell_animation >= 80) & (state.level.bell_animation <= 95)) should_draw_bell = (state.level.bell_position[0] != -1) & ~jnp.any(state.level.fruit_stages == 3) bell_offset = self.FLIP_OFFSETS["bell"] raster = jax.lax.cond(should_draw_bell, lambda r: self.jr.render_at(r, state.level.bell_position[0].astype(int), state.level.bell_position[1].astype(int), bell_mask, flip_horizontal=flip_bell, flip_offset=bell_offset), lambda r: r, raster) return raster @partial(jax.jit, static_argnums=(0,)) def _draw_ladders(self, raster: jnp.ndarray, state: KangarooState): """Draws the ladders using the utility function.""" return self.jr.draw_ladders( raster, state.level.ladder_positions, state.level.ladder_sizes, self.ladder_rung_height, self.ladder_space_height, self.LADDER_COLOR_ID ) @partial(jax.jit, static_argnums=(0,)) def _draw_single_fruit(self, i, raster, state: KangarooState): """ Draws a single fruit based on index i. Designed to be called within a jax.lax.fori_loop. """ should_draw = state.level.fruit_actives[i] fruit_type = state.level.fruit_stages[i].astype(int) pos = state.level.fruit_positions[i] fruit_mask = self.SHAPE_MASKS["fruit"][fruit_type] fruit_offset = self.FLIP_OFFSETS["fruit"] draw_fn = lambda r: self.jr.render_at( r, pos[0].astype(int), pos[1].astype(int), fruit_mask, flip_offset=fruit_offset ) return jax.lax.cond(should_draw, draw_fn, lambda r: r, raster) @partial(jax.jit, static_argnums=(0,)) def _draw_single_monkey(self, i, raster, state: KangarooState): """ Draws a single monkey based on index i. Designed to be called within a jax.lax.fori_loop. """ state_idx = state.level.monkey_states[i].astype(int) pos = state.level.monkey_positions[i] # Map game state to sprite index monkey_sprite_idx = jnp.array([0, 1, 2, 3, 2, 4])[state_idx] is_walking = (state_idx == 2) | (state_idx == 4) use_standing_anim = is_walking & ((state.level.step_counter % 32) < 16) # Index 0 is 'standing' final_sprite_idx = jax.lax.select(use_standing_anim, 0, monkey_sprite_idx) monkey_mask = self.SHAPE_MASKS["ape"][final_sprite_idx] flip_offset = self.FLIP_OFFSETS["ape"] flip_h = (state_idx == 4) should_draw = (state_idx != 0) draw_fn = lambda r: self.jr.render_at_clipped( r, pos[0].astype(int), pos[1].astype(int), monkey_mask, flip_horizontal=flip_h, flip_offset=flip_offset ) return jax.lax.cond(should_draw, draw_fn, lambda r: r, raster)
[docs] @partial(jax.jit, static_argnums=(0,)) def render(self, state: KangarooState) -> chex.Array: # --- 1. Initialize Raster --- raster = self.jr.create_object_raster(self.BACKGROUND) raster = self.jr.draw_rects( raster, state.level.platform_positions, state.level.platform_sizes, self.PLATFORM_COLOR_ID ) raster = self._draw_ladders(raster, state) # --- 3. Draw Dynamic Objects --- # Fruits - UPDATED CALL raster = jax.lax.fori_loop( 0, state.level.fruit_positions.shape[0], lambda i, r: self._draw_single_fruit(i, r, state), raster ) # Bell raster = self._draw_bell(raster, state) # Monkeys (Apes) raster = jax.lax.fori_loop( 0, state.level.monkey_positions.shape[0], lambda i, r: self._draw_single_monkey(i, r, state), raster ) # Player (Kangaroo) is_walking_anim = (state.player.walk_animation > 6) & (state.player.walk_animation < 16) & \ ~state.player.is_crouching & ~state.player.is_jumping & \ ~state.player.is_climbing & ~state.player.is_crashing is_high_jump = (state.player.jump_counter > 16) & (state.player.jump_counter < 25) # Select sprite index based on player state player_sprite_idx = jax.lax.cond(state.player.is_crashing, lambda: 1, lambda: jax.lax.cond(state.player.is_climbing, lambda: 2, lambda: jax.lax.cond(state.player.is_crouching, lambda: 3, lambda: jax.lax.cond(is_high_jump, lambda: 7, lambda: jax.lax.cond(state.player.is_jumping, lambda: 4, lambda: jax.lax.cond(state.player.punch_left | state.player.punch_right, lambda: 5, lambda: jax.lax.cond(is_walking_anim, lambda: 6, lambda: 0))))))) player_mask = self.SHAPE_MASKS["kangaroo"][player_sprite_idx] flip_offset = self.FLIP_OFFSETS["kangaroo"] flip_player = state.player.orientation < 0 player_y_offset = jax.lax.select(is_walking_anim, -1, 0) raster = self.jr.render_at( raster, x=state.player.x.astype(int), y=(state.player.y + player_y_offset).astype(int), sprite_mask=player_mask, flip_horizontal=flip_player, flip_offset=flip_offset ) # Child is_jumping = (state.level.step_counter % 32) < 16 child_idx = jax.lax.select(is_jumping, 1, 0) child_mask = self.SHAPE_MASKS["child"][child_idx] child_offset = self.FLIP_OFFSETS["child"] flip_child = state.level.child_velocity > 0 raster = jax.lax.cond(state.level.child_position[0] != -1, lambda r: self.jr.render_at(r, state.level.child_position[0].astype(int), state.level.child_position[1].astype(int), child_mask, flip_horizontal=flip_child, flip_offset=child_offset), lambda r: r, raster) # Coconuts coconut_offset = self.FLIP_OFFSETS["falling_coconut"] should_draw_falling_coco = (state.level.falling_coco_position[0] != 13) | (state.level.falling_coco_position[1] != -1) raster = jax.lax.cond(should_draw_falling_coco, lambda r: self.jr.render_at(r, state.level.falling_coco_position[0].astype(int), state.level.falling_coco_position[1].astype(int), self.SHAPE_MASKS["falling_coconut"], flip_offset=coconut_offset), lambda r: r, raster) def _draw_coco(i, current_raster): should_draw = (state.level.coco_states[i] != 0) pos = state.level.coco_positions[i] coco_offset = self.FLIP_OFFSETS["coconut"] draw_fn = lambda r: self.jr.render_at(r, pos[0].astype(int), pos[1].astype(int), self.SHAPE_MASKS["coconut"], flip_offset=coco_offset) return jax.lax.cond(should_draw, draw_fn, lambda r: r, current_raster) raster = jax.lax.fori_loop(0, state.level.coco_positions.shape[0], _draw_coco, raster) # --- 4. Draw UI --- # Score score_digits = self.jr.int_to_digits(state.score, max_digits=6) raster = self.jr.render_label(raster, 105, 182, score_digits, self.SHAPE_MASKS["score_digits"], spacing=8, max_digits=6) # Lives lives_count = jnp.maximum(state.lives.astype(int) - 1, 0) raster = self.jr.render_indicator(raster, 15, 182, lives_count, self.SHAPE_MASKS["lives"], spacing=8, max_value=5) # Timer timer_digits = self.jr.int_to_digits(jnp.maximum(state.level.timer.astype(int), 0), max_digits=4) raster = self.jr.render_label(raster, 80, 190, timer_digits, self.SHAPE_MASKS["time_digits"], spacing=4, max_digits=4) # Hook for modifications raster = self._render_hook_post_ui(raster, state) # --- 5. Final Palette Lookup --- return self.jr.render_from_palette(raster, self.PALETTE)