"""Source code for ``jaxatari.games.jax_pong``: the Pong environment (``JaxPong``) and its renderer (``PongRenderer``)."""

from jax._src.pjit import JitWrapped
import os
from functools import partial
from typing import NamedTuple, Tuple
import jax
import jax.lax
import jax.numpy as jnp
import chex

import jaxatari.spaces as spaces
from jaxatari.renderers import JAXGameRenderer
from jaxatari.rendering import jax_rendering_utils as render_utils
from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action

def _create_wall_sprite(consts: "PongConstants", height: int) -> jnp.ndarray:
    wall_color_rgba = (*consts.SCORE_COLOR, 255)
    wall_shape = (height, consts.WIDTH, 4)
    return jnp.tile(jnp.array(wall_color_rgba, dtype=jnp.uint8), (*wall_shape[:2], 1))

def _get_default_asset_config() -> tuple:
    """
    Returns the default declarative asset manifest for Pong.
    Kept immutable (tuple of dicts) to fit NamedTuple defaults.
    """
    return (
        {'name': 'background', 'type': 'background', 'file': 'background.npy'},
        {'name': 'player', 'type': 'single', 'file': 'player.npy'},
        {'name': 'enemy', 'type': 'single', 'file': 'enemy.npy'},
        {'name': 'ball', 'type': 'single', 'file': 'ball.npy'},
        {'name': 'player_digits', 'type': 'digits', 'pattern': 'player_score_{}.npy'},
        {'name': 'enemy_digits', 'type': 'digits', 'pattern': 'enemy_score_{}.npy'},
    )


class PongConstants(NamedTuple):
    """Tunable constants shared by :class:`JaxPong` and :class:`PongRenderer`.

    Both constructors accept an instance of this class, so a modified copy
    (e.g. via ``PongConstants()._replace(...)``) changes game parameters or
    assets. Sizes are ``(width, height)`` pairs in pixels; positions are pixel
    coordinates with y growing downwards (top wall above bottom wall).
    """
    # Player paddle speed is clamped to [-MAX_SPEED, MAX_SPEED].
    MAX_SPEED: int = 12
    # Initial ball velocity as (vel_x, vel_y), applied on reset.
    BALL_SPEED: chex.Array = jnp.array([-1, 1])
    # Pixels the enemy paddle moves on a frame where it moves.
    ENEMY_STEP_SIZE: int = 2
    # Playfield size in pixels (WIDTH is also the wall sprite width).
    WIDTH: int = 160
    HEIGHT: int = 210
    # Ball speed limits.
    # NOTE(review): these three do not appear to be read by the step logic
    # in this module -- confirm whether they are meant to be enforced.
    BASE_BALL_SPEED: int = 1
    BALL_MAX_SPEED: int = 4
    MIN_BALL_SPEED: int = 1
    # Per-update player acceleration, indexed by the player's acceleration
    # counter (number of consecutive updates with a move input).
    PLAYER_ACCELERATION: chex.Array = jnp.array([6, 3, 1, -1, 1, -1, 0, 0, 1, 0, -1, 0, 1])
    # Ball spawn position; BALL_START_Y is also where the enemy paddle is
    # placed after a goal.
    BALL_START_X: chex.Array = jnp.array(78)
    BALL_START_Y: chex.Array = jnp.array(115)
    # (R, G, B) colours of the game elements.
    BACKGROUND_COLOR: Tuple[int, int, int] = (144, 72, 17)
    PLAYER_COLOR: Tuple[int, int, int] = (92, 186, 92)
    ENEMY_COLOR: Tuple[int, int, int] = (213, 130, 74)
    BALL_COLOR: Tuple[int, int, int] = (236, 236, 236)
    WALL_COLOR: Tuple[int, int, int] = (236, 236, 236)
    SCORE_COLOR: Tuple[int, int, int] = (236, 236, 236)
    # Fixed x (left edge) of the player (right) and enemy (left) paddles.
    PLAYER_X: int = 140
    ENEMY_X: int = 16
    # Object sizes as (width, height).
    PLAYER_SIZE: Tuple[int, int] = (4, 16)
    BALL_SIZE: Tuple[int, int] = (2, 4)
    ENEMY_SIZE: Tuple[int, int] = (4, 16)
    # Top/bottom wall y position and thickness in pixels.
    WALL_TOP_Y: int = 24
    WALL_TOP_HEIGHT: int = 10
    WALL_BOTTOM_Y: int = 194
    WALL_BOTTOM_HEIGHT: int = 16
    # Asset config baked into constants (immutable default) for asset overrides
    ASSET_CONFIG: tuple = _get_default_asset_config()
# immutable state container
class PongState(NamedTuple):
    """Complete state of one Pong episode.

    Instances are never mutated; the step helpers return updated copies.
    Screen y grows downwards, so a negative speed moves a paddle up.
    """
    player_y: chex.Array  # top edge of the player paddle (taken from `buffer`, so it lags)
    player_speed: chex.Array  # signed vertical paddle speed; negative = up
    ball_x: chex.Array  # ball left edge
    ball_y: chex.Array  # ball top edge
    enemy_y: chex.Array  # top edge of the enemy paddle
    enemy_speed: chex.Array  # initialised to 0 on reset
    ball_vel_x: chex.Array  # horizontal ball velocity per frame; > 0 moves toward the player
    ball_vel_y: chex.Array  # vertical ball velocity per frame
    player_score: chex.Array
    enemy_score: chex.Array
    step_counter: chex.Array  # frames since reset or the last goal (reset to 0 on a goal)
    acceleration_counter: chex.Array  # index into PongConstants.PLAYER_ACCELERATION
    buffer: chex.Array  # buffered player y that delays paddle movement
    key: chex.PRNGKey  # PRNG key; split on every step
class EntityPosition(NamedTuple):
    """Axis-aligned bounding box of a game object, in screen pixels."""
    x: jnp.ndarray  # left edge
    y: jnp.ndarray  # top edge
    width: jnp.ndarray
    height: jnp.ndarray
class PongObservation(NamedTuple):
    """Object-centric observation: bounding boxes of all objects plus both scores."""
    player: EntityPosition
    enemy: EntityPosition
    ball: EntityPosition
    score_player: jnp.ndarray
    score_enemy: jnp.ndarray
class PongInfo(NamedTuple):
    """Auxiliary per-step information returned by ``JaxPong.step``."""
    time: jnp.ndarray  # the state's step_counter (frames since reset or the last goal)
class JaxPong(JaxEnvironment[PongState, PongObservation, PongInfo, PongConstants]):
    """Pong implemented as pure JAX state transitions.

    The agent steers the right-hand paddle (at ``PLAYER_X``); the left-hand
    paddle (at ``ENEMY_X``) is scripted to follow the ball. ``step`` is
    ``jax.jit``-compiled with ``self`` as a static argument.
    """

    # Minimal ALE action set for Pong:
    # 0=NOOP, 1=FIRE, 2=RIGHT, 3=LEFT, 4=RIGHTFIRE, 5=LEFTFIRE
    ACTION_SET: jnp.ndarray = jnp.array(
        [Action.NOOP, Action.FIRE, Action.RIGHT, Action.LEFT, Action.RIGHTFIRE, Action.LEFTFIRE],
        dtype=jnp.int32,
    )

    def __init__(self, consts: PongConstants = None):
        """Create the environment.

        Args:
            consts: Game constants; defaults to ``PongConstants()``.
        """
        consts = consts or PongConstants()
        super().__init__(consts)
        self.renderer = PongRenderer(self.consts)

    def _player_step(self, state: PongState, action: chex.Array) -> PongState:
        """Update the player paddle for one frame.

        LEFT/LEFTFIRE move the paddle up (negative speed) and RIGHT/RIGHTFIRE
        move it down. The speed follows ``PLAYER_ACCELERATION``. It is halved
        when there is no move input or the paddle is past a wall, and zeroed
        on a direction reversal. Updates are applied only on even frames, and
        the visible position is delayed through ``state.buffer``.

        Args:
            state: Current game state.
            action: ALE action code (already mapped through ``ACTION_SET``).

        Returns:
            The state with ``player_y``, ``player_speed``,
            ``acceleration_counter`` and ``buffer`` updated.
        """
        up = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE)
        down = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE)

        # The counter saturates at 15, which can exceed the table length (13 by
        # default). Clamp the index explicitly instead of relying on JAX's
        # out-of-bounds indexing behaviour (which clamps in practice, so the
        # resulting acceleration values are unchanged).
        accel_table = self.consts.PLAYER_ACCELERATION
        accel_index = jnp.minimum(state.acceleration_counter, accel_table.shape[0] - 1)
        acceleration = accel_table[accel_index]

        touches_wall = jnp.logical_or(
            state.player_y < self.consts.WALL_TOP_Y,
            state.player_y + self.consts.PLAYER_SIZE[1] > self.consts.WALL_BOTTOM_Y,
        )

        # Friction: halve the speed when idle or when touching a wall.
        player_speed = jax.lax.cond(
            jnp.logical_or(jnp.logical_not(jnp.logical_or(up, down)), touches_wall),
            lambda s: jnp.round(s / 2).astype(jnp.int32),
            lambda s: s,
            operand=state.player_speed,
        )

        # Reversing direction stops the paddle and restarts the acceleration profile.
        direction_change_up = jnp.logical_and(up, state.player_speed > 0)
        player_speed = jax.lax.cond(
            direction_change_up,
            lambda s: 0,
            lambda s: s,
            operand=player_speed,
        )
        direction_change_down = jnp.logical_and(down, state.player_speed < 0)
        player_speed = jax.lax.cond(
            direction_change_down,
            lambda s: 0,
            lambda s: s,
            operand=player_speed,
        )
        direction_change = jnp.logical_or(direction_change_up, direction_change_down)
        acceleration_counter = jax.lax.cond(
            direction_change,
            lambda _: 0,
            lambda s: s,
            operand=state.acceleration_counter,
        )

        # NOTE(review): `acceleration` was read with the pre-reversal counter,
        # so a reversal frame uses the old table entry -- confirm intended.
        player_speed = jax.lax.cond(
            up,
            lambda s: jnp.maximum(s - acceleration, -self.consts.MAX_SPEED),
            lambda s: s,
            operand=player_speed,
        )
        player_speed = jax.lax.cond(
            down,
            lambda s: jnp.minimum(s + acceleration, self.consts.MAX_SPEED),
            lambda s: s,
            operand=player_speed,
        )

        new_acceleration_counter = jax.lax.cond(
            jnp.logical_or(up, down),
            lambda s: jnp.minimum(s + 1, 15),
            lambda s: 0,
            operand=acceleration_counter,
        )

        proposed_player_y = jnp.clip(
            state.player_y + player_speed,
            self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - 10,
            self.consts.WALL_BOTTOM_Y - 4,
        )

        # Match original timing/buffering behavior: movement is only
        # committed on even frames.
        new_player_y, new_player_speed, new_acc_counter = jax.lax.cond(
            state.step_counter % 2 == 0,
            lambda _: (proposed_player_y, player_speed, new_acceleration_counter),
            lambda _: (state.player_y, state.player_speed, state.acceleration_counter),
            operand=None,
        )

        # The visible position is read from the buffer. The buffer only accepts
        # a new target once the paddle has caught up with its current value.
        buffer = jax.lax.cond(
            jax.lax.eq(state.buffer, state.player_y),
            lambda _: new_player_y,
            lambda _: state.buffer,
            operand=None,
        )

        return state._replace(
            player_y=state.buffer,
            player_speed=new_player_speed,
            acceleration_counter=new_acc_counter,
            buffer=buffer,
        )

    def _ball_step(self, state: PongState, action) -> PongState:
        """Move the ball one frame and resolve wall and paddle collisions.

        A wall contact flips the vertical velocity. A paddle hit sets the
        vertical velocity from the zone of the paddle that was struck and
        reverses the horizontal velocity. The speed is first increased by one
        if the player hit while holding a FIRE action or moving at
        ``+MAX_SPEED``.

        Args:
            state: Current game state (paddles already moved this frame).
            action: ALE action code (already mapped through ``ACTION_SET``).

        Returns:
            The state with the ball position and velocity updated (int32).
        """
        ball_x = state.ball_x + state.ball_vel_x
        ball_y = state.ball_y + state.ball_vel_y

        wall_bounce = jnp.logical_or(
            ball_y <= self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - self.consts.BALL_SIZE[1],
            ball_y >= self.consts.WALL_BOTTOM_Y,
        )
        ball_vel_y = jnp.where(wall_bounce, -state.ball_vel_y, state.ball_vel_y)

        player_paddle_hit = jnp.logical_and(
            jnp.logical_and(
                self.consts.PLAYER_X <= ball_x,
                ball_x <= self.consts.PLAYER_X + self.consts.PLAYER_SIZE[0],
            ),
            state.ball_vel_x > 0,
        )
        player_paddle_hit = jnp.logical_and(
            player_paddle_hit,
            jnp.logical_and(
                state.player_y - self.consts.BALL_SIZE[1] <= ball_y,
                ball_y <= state.player_y + self.consts.PLAYER_SIZE[1] + self.consts.BALL_SIZE[1],
            ),
        )

        enemy_paddle_hit = jnp.logical_and(
            jnp.logical_and(
                self.consts.ENEMY_X <= ball_x,
                ball_x <= self.consts.ENEMY_X + self.consts.ENEMY_SIZE[0] - 1,
            ),
            state.ball_vel_x < 0,
        )
        enemy_paddle_hit = jnp.logical_and(
            enemy_paddle_hit,
            jnp.logical_and(
                state.enemy_y - self.consts.BALL_SIZE[1] <= ball_y,
                ball_y <= state.enemy_y + self.consts.ENEMY_SIZE[1] + self.consts.BALL_SIZE[1],
            ),
        )

        paddle_hit = jnp.logical_or(player_paddle_hit, enemy_paddle_hit)

        # Both paddles are split into five equal zones (using the player's height).
        section_height = self.consts.PLAYER_SIZE[1] / 5

        def hit_zone(paddle_y):
            # Map the ball's y relative to the paddle top to a new vertical
            # velocity in {-2, -1, 0, 1, 2} (top zone -> steepest upward).
            return jnp.where(
                ball_y < paddle_y + section_height,
                -2.0,
                jnp.where(
                    ball_y < paddle_y + 2 * section_height,
                    -1.0,
                    jnp.where(
                        ball_y < paddle_y + 3 * section_height,
                        0.0,
                        jnp.where(
                            ball_y < paddle_y + 4 * section_height,
                            1.0,
                            2.0,
                        ),
                    ),
                ),
            )

        hit_position = jnp.where(
            paddle_hit,
            jnp.where(player_paddle_hit, hit_zone(state.player_y), hit_zone(state.enemy_y)),
            0.0,
        )
        ball_vel_y = jnp.where(paddle_hit, hit_position, ball_vel_y)

        fire_pressed = jnp.logical_or(
            jnp.logical_or(action == Action.LEFTFIRE, action == Action.RIGHTFIRE),
            action == Action.FIRE,
        )
        boost_triggered = jnp.logical_and(player_paddle_hit, fire_pressed)
        # NOTE(review): only a paddle moving down at exactly +MAX_SPEED counts
        # as a max-speed hit, and ball_vel_x is never capped here
        # (BALL_MAX_SPEED is not used) -- confirm both are intended.
        player_max_hit = jnp.logical_and(player_paddle_hit, state.player_speed == self.consts.MAX_SPEED)
        ball_vel_x = jnp.where(
            jnp.logical_or(boost_triggered, player_max_hit),
            state.ball_vel_x + jnp.sign(state.ball_vel_x),
            state.ball_vel_x,
        )
        ball_vel_x = jnp.where(paddle_hit, -ball_vel_x, ball_vel_x)

        return state._replace(
            ball_x=ball_x.astype(jnp.int32),
            ball_y=ball_y.astype(jnp.int32),
            ball_vel_x=ball_vel_x.astype(jnp.int32),
            ball_vel_y=ball_vel_y.astype(jnp.int32),
        )

    def _enemy_step(self, state: PongState) -> PongState:
        """Move the enemy paddle ``ENEMY_STEP_SIZE`` pixels toward the ball's y.

        The paddle stays put on every 8th frame (``step_counter % 8 == 0``).
        """
        should_move = state.step_counter % 8 != 0
        direction = jnp.sign(state.ball_y - state.enemy_y)
        new_y = state.enemy_y + (direction * self.consts.ENEMY_STEP_SIZE).astype(jnp.int32)
        enemy_y = jax.lax.cond(
            should_move,
            lambda _: new_y,
            lambda _: state.enemy_y,
            operand=None,
        )
        return state._replace(enemy_y=enemy_y.astype(jnp.int32))

    def _score_and_reset(self, state: PongState) -> PongState:
        """Award goals, respawn the ball after a goal, and advance the frame counter.

        The ball leaving on the left (``x < 4``) scores for the player, and on
        the right (``x > 156``) for the enemy. After a goal the ball is
        re-served, the enemy paddle is recentred and ``step_counter`` restarts
        at 0. While ``step_counter < 60`` the ball is held at its start position.
        """
        player_goal = state.ball_x < 4
        enemy_goal = state.ball_x > 156
        ball_reset = jnp.logical_or(enemy_goal, player_goal)

        player_score = jax.lax.cond(
            player_goal,
            lambda s: s + 1,
            lambda s: s,
            operand=state.player_score,
        )
        enemy_score = jax.lax.cond(
            enemy_goal,
            lambda s: s + 1,
            lambda s: s,
            operand=state.enemy_score,
        )

        current_values = (
            state.ball_x.astype(jnp.int32),
            state.ball_y.astype(jnp.int32),
            state.ball_vel_x.astype(jnp.int32),
            state.ball_vel_y.astype(jnp.int32),
        )
        ball_x_final, ball_y_final, ball_vel_x_final, ball_vel_y_final = jax.lax.cond(
            ball_reset,
            lambda x: self._reset_ball_after_goal((state, enemy_goal)),
            lambda x: x,
            operand=current_values,
        )

        step_counter = jax.lax.cond(
            ball_reset,
            lambda s: jnp.array(0),
            lambda s: s + 1,
            operand=state.step_counter,
        )

        enemy_y_final = jax.lax.cond(
            ball_reset,
            lambda s: self.consts.BALL_START_Y.astype(jnp.int32),
            lambda s: state.enemy_y.astype(jnp.int32),
            operand=None,
        )

        # Hold the ball at its spawn point during the serve delay.
        ball_x_final = jax.lax.cond(
            step_counter < 60,
            lambda s: self.consts.BALL_START_X.astype(jnp.int32),
            lambda s: s,
            operand=ball_x_final,
        )
        ball_y_final = jax.lax.cond(
            step_counter < 60,
            lambda s: self.consts.BALL_START_Y.astype(jnp.int32),
            lambda s: s,
            operand=ball_y_final,
        )

        return state._replace(
            ball_x=ball_x_final,
            ball_y=ball_y_final,
            enemy_y=enemy_y_final,
            ball_vel_x=ball_vel_x_final,
            ball_vel_y=ball_vel_y_final,
            player_score=player_score,
            enemy_score=enemy_score,
            step_counter=step_counter,
        )

    def _reset_ball_after_goal(self, state_and_goal: Tuple[PongState, bool]) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
        """Compute the re-serve position and velocity after a goal.

        Args:
            state_and_goal: ``(state, scored_right)``, where ``scored_right`` is
                true when the ball left on the right-hand side.

        Returns:
            ``(ball_x, ball_y, ball_vel_x, ball_vel_y)`` as int32. The ball
            starts at ``BALL_START_X/Y``. It moves right (+1) if ``scored_right``,
            otherwise left (-1), and vertically away from the start row based
            on where the ball was.
        """
        state, scored_right = state_and_goal
        ball_vel_y = jnp.where(
            state.ball_y > self.consts.BALL_START_Y,
            1,
            -1,
        ).astype(jnp.int32)
        ball_vel_x = jnp.where(
            scored_right, 1, -1
        ).astype(jnp.int32)
        return (
            self.consts.BALL_START_X.astype(jnp.int32),
            self.consts.BALL_START_Y.astype(jnp.int32),
            ball_vel_x.astype(jnp.int32),
            ball_vel_y.astype(jnp.int32),
        )

    def reset(self, key: chex.PRNGKey = jax.random.PRNGKey(42)) -> Tuple[PongObservation, PongState]:
        """Start a new episode.

        Args:
            key: PRNG key; the first half of its split is stored in the state.

        Returns:
            ``(initial_observation, initial_state)``.
        """
        # Split key for env reset if needed and for state storage
        state_key, _step_key = jax.random.split(key)
        state = PongState(
            player_y=jnp.array(96).astype(jnp.int32),
            player_speed=jnp.array(0.0).astype(jnp.int32),
            ball_x=self.consts.BALL_START_X.astype(jnp.int32),
            ball_y=self.consts.BALL_START_Y.astype(jnp.int32),
            enemy_y=jnp.array(115).astype(jnp.int32),
            enemy_speed=jnp.array(0.0).astype(jnp.int32),
            ball_vel_x=self.consts.BALL_SPEED[0].astype(jnp.int32),
            ball_vel_y=self.consts.BALL_SPEED[1].astype(jnp.int32),
            player_score=jnp.array(0).astype(jnp.int32),
            enemy_score=jnp.array(0).astype(jnp.int32),
            step_counter=jnp.array(0).astype(jnp.int32),
            acceleration_counter=jnp.array(0).astype(jnp.int32),
            buffer=jnp.array(96).astype(jnp.int32),
            key=state_key,
        )
        initial_obs = self._get_observation(state)
        return initial_obs, state

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state: PongState, action: chex.Array) -> Tuple[PongObservation, PongState, float, bool, PongInfo]:
        """Advance the game by one frame.

        Args:
            state: Current game state.
            action: Index into ``ACTION_SET`` (0..5).

        Returns:
            ``(observation, new_state, reward, done, info)``.
        """
        # Translate compact agent action index to ALE console action
        atari_action = jnp.take(self.ACTION_SET, action.astype(jnp.int32))

        # Split step key from state and keep a new key for the next state
        new_state_key, step_key = jax.random.split(state.key)
        previous_state = state

        # Make per-step key available to helpers that may read state.key
        state = state._replace(key=step_key)

        # Order matters: paddles move before the ball collides with them.
        state = self._player_step(state, atari_action)
        state = self._enemy_step(state)
        state = self._ball_step(state, atari_action)
        state = self._score_and_reset(state)

        # Update state key to new_state_key for next step
        state = state._replace(key=new_state_key)

        done = self._get_done(state)
        env_reward = self._get_reward(previous_state, state)
        info = self._get_info(state)
        observation = self._get_observation(state)

        return observation, state, env_reward, done, info

    def render(self, state: PongState) -> jnp.ndarray:
        """Render ``state`` to an image using the environment's ``PongRenderer``."""
        return self.renderer.render(state)

    def _get_observation(self, state: PongState):
        """Build the object-centric observation for ``state``."""
        player = EntityPosition(
            x=jnp.array(self.consts.PLAYER_X),
            y=state.player_y,
            width=jnp.array(self.consts.PLAYER_SIZE[0]),
            height=jnp.array(self.consts.PLAYER_SIZE[1]),
        )
        enemy = EntityPosition(
            x=jnp.array(self.consts.ENEMY_X),
            y=state.enemy_y,
            width=jnp.array(self.consts.ENEMY_SIZE[0]),
            height=jnp.array(self.consts.ENEMY_SIZE[1]),
        )
        ball = EntityPosition(
            x=state.ball_x,
            y=state.ball_y,
            width=jnp.array(self.consts.BALL_SIZE[0]),
            height=jnp.array(self.consts.BALL_SIZE[1]),
        )
        return PongObservation(
            player=player,
            enemy=enemy,
            ball=ball,
            score_player=state.player_score,
            score_enemy=state.enemy_score,
        )

    @partial(jax.jit, static_argnums=(0,))
    def obs_to_flat_array(self, obs: PongObservation) -> jnp.ndarray:
        """Flatten an observation into a 1-D array.

        Each entity contributes ``x, y, height, width`` (player, enemy, ball),
        followed by the player score and the enemy score.
        """
        # NOTE(review): height precedes width here, unlike EntityPosition's
        # field order; presumably kept for consumers of the flat layout --
        # do not reorder without checking them.
        return jnp.concatenate([
            obs.player.x.flatten(),
            obs.player.y.flatten(),
            obs.player.height.flatten(),
            obs.player.width.flatten(),
            obs.enemy.x.flatten(),
            obs.enemy.y.flatten(),
            obs.enemy.height.flatten(),
            obs.enemy.width.flatten(),
            obs.ball.x.flatten(),
            obs.ball.y.flatten(),
            obs.ball.height.flatten(),
            obs.ball.width.flatten(),
            obs.score_player.flatten(),
            obs.score_enemy.flatten(),
        ])

    def action_space(self) -> spaces.Discrete:
        """Return the discrete action space (one action per ``ACTION_SET`` entry)."""
        return spaces.Discrete(len(self.ACTION_SET))

    def observation_space(self) -> spaces.Dict:
        """Return the observation space matching ``PongObservation``."""
        return spaces.Dict({
            "player": spaces.Dict({
                "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
                "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
            }),
            "enemy": spaces.Dict({
                "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
                "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
            }),
            "ball": spaces.Dict({
                "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
                "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32),
                "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32),
            }),
            "score_player": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32),
            "score_enemy": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32),
        })

    def image_space(self) -> spaces.Box:
        """Return the space of rendered frames: 210x160 RGB ``uint8`` images."""
        return spaces.Box(
            low=0,
            high=255,
            shape=(210, 160, 3),
            dtype=jnp.uint8,
        )

    @partial(jax.jit, static_argnums=(0,))
    def _get_info(self, state: PongState, ) -> PongInfo:
        """Return auxiliary info (the current ``step_counter``)."""
        return PongInfo(time=state.step_counter)

    @partial(jax.jit, static_argnums=(0,))
    def _get_reward(self, previous_state: PongState, state: PongState):
        """Return the change in (player score - enemy score) over the step."""
        return (state.player_score - state.enemy_score) - (
            previous_state.player_score - previous_state.enemy_score
        )

    @partial(jax.jit, static_argnums=(0,))
    def _get_done(self, state: PongState) -> bool:
        """Return whether either side has reached 21 points."""
        return jnp.logical_or(
            jnp.greater_equal(state.player_score, 21),
            jnp.greater_equal(state.enemy_score, 21),
        )
class PongRenderer(JAXGameRenderer):
    """Palette-based renderer for Pong states.

    Assets are loaded once in ``__init__``: the sprites listed in
    ``consts.ASSET_CONFIG`` (from ``sprites/pong`` next to this module) plus
    two wall sprites generated from the constants. The jitted ``render``
    method then stamps the objects onto a raster each frame.
    """

    def __init__(self, consts: PongConstants = None):
        super().__init__(consts)
        self.consts = consts or PongConstants()
        self.config = render_utils.RendererConfig(
            game_dimensions=(210, 160),
            channels=3,
            # downscale=(84, 84)
        )
        self.jr = render_utils.JaxRenderingUtils(self.config)

        # 1. Start from (possibly modded) asset config provided via constants.
        final_asset_config = list(self.consts.ASSET_CONFIG)

        # 2. Create procedural wall sprites from the (possibly modded) constants.
        wall_sprite_top = _create_wall_sprite(self.consts, self.consts.WALL_TOP_HEIGHT)
        wall_sprite_bottom = _create_wall_sprite(self.consts, self.consts.WALL_BOTTOM_HEIGHT)

        # 3. Append the procedural assets to the manifest.
        final_asset_config.append({'name': 'wall_top', 'type': 'procedural', 'data': wall_sprite_top})
        final_asset_config.append({'name': 'wall_bottom', 'type': 'procedural', 'data': wall_sprite_bottom})

        # 4. Bake assets once.
        sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/pong"
        (
            self.PALETTE,
            self.SHAPE_MASKS,
            self.BACKGROUND,
            self.COLOR_TO_ID,
            self.FLIP_OFFSETS,
        ) = self.jr.load_and_setup_assets(final_asset_config, sprite_path)

    def _render_score(self, raster, score, digit_masks, base_x):
        """Stamp a score of up to two digits at row 3.

        Two-digit scores start at ``base_x``. Single-digit scores skip the
        leading digit and are shifted right by half a digit (8 px), so they
        appear centred in the two-digit slot.
        """
        digits = self.jr.int_to_digits(score, max_digits=2)
        is_single_digit = score < 10
        start_index = jax.lax.select(is_single_digit, 1, 0)
        num_to_render = jax.lax.select(is_single_digit, 1, 2)
        render_x = jax.lax.select(is_single_digit, base_x + 16 // 2, base_x)
        return self.jr.render_label_selective(
            raster, render_x, 3, digits, digit_masks,
            start_index, num_to_render, spacing=16,
        )

    @partial(jax.jit, static_argnums=(0,))
    def render(self, state):
        """Render ``state`` to an RGB image.

        Draw order: paddles, ball, walls, then scores. Later layers are drawn
        over earlier ones.
        """
        raster = self.jr.create_object_raster(self.BACKGROUND)

        raster = self.jr.render_at(raster, self.consts.PLAYER_X, state.player_y, self.SHAPE_MASKS["player"])
        raster = self.jr.render_at(raster, self.consts.ENEMY_X, state.enemy_y, self.SHAPE_MASKS["enemy"])
        raster = self.jr.render_at(raster, state.ball_x, state.ball_y, self.SHAPE_MASKS["ball"])

        # Walls use separate sprites for top and bottom.
        raster = self.jr.render_at(raster, 0, self.consts.WALL_TOP_Y, self.SHAPE_MASKS["wall_top"])
        raster = self.jr.render_at(raster, 0, self.consts.WALL_BOTTOM_Y, self.SHAPE_MASKS["wall_bottom"])

        # Scores: player on the right (x=120), enemy on the left (x=10).
        raster = self._render_score(raster, state.player_score, self.SHAPE_MASKS["player_digits"], 120)
        raster = self._render_score(raster, state.enemy_score, self.SHAPE_MASKS["enemy_digits"], 10)

        return self.jr.render_from_palette(raster, self.PALETTE)