Spaces
- class jaxatari.spaces.Box(low: float | numpy.ndarray | jax.Array, high: float | numpy.ndarray | jax.Array, shape: Tuple[int, ...] | None = None, dtype: numpy.dtype = jax.numpy.float32)
Bases: Space
A jittable n-dimensional box space.
This space represents the Cartesian product of n closed intervals, each with its own lower and upper bound. It can be initialized in two ways:
1. With scalar bounds and an explicit shape, creating a box with uniform bounds. Example: Box(low=0.0, high=1.0, shape=(3, 4))
2. With array-like bounds, in which case the shape is inferred from the bound arrays. Example: Box(low=jnp.array([0., -1.]), high=jnp.array([1., 1.]))
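The two construction modes come down to one rule: scalar bounds are broadcast to an explicit shape, and array bounds supply the shape themselves. The sketch below illustrates that rule with NumPy. `box_bounds` and `box_contains` are hypothetical helpers, not part of jaxatari, and the check that `low <= high` is an assumed validation step.

```python
import numpy as np

def box_bounds(low, high, shape=None, dtype=np.float32):
    """Hypothetical helper mirroring the two Box construction modes.

    Returns (low, high) as arrays that share a common shape.
    """
    low = np.asarray(low, dtype=dtype)
    high = np.asarray(high, dtype=dtype)
    if shape is None:
        # Mode 2: infer the shape from the array-like bounds.
        if low.shape != high.shape:
            raise ValueError("low and high must have the same shape")
        shape = low.shape
    # Mode 1: broadcast scalar bounds to `shape` (a no-op in mode 2).
    low = np.broadcast_to(low, shape)
    high = np.broadcast_to(high, shape)
    # Assumed validation: every interval must be non-empty.
    if np.any(low > high):
        raise ValueError("every lower bound must be <= its upper bound")
    return low, high

def box_contains(x, low, high):
    """True if x has the box's shape and lies in every closed interval."""
    x = np.asarray(x)
    return x.shape == low.shape and bool(np.all((x >= low) & (x <= high)))
```

For example, `box_bounds(0.0, 1.0, shape=(3, 4))` yields two `(3, 4)` arrays. With the array bounds from the second example, `[0.5, -0.5]` lies inside the box and `[0.5, 2.0]` does not.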
- class jaxatari.spaces.Dict(spaces: dict)
Bases: Space
A jittable dictionary of simpler jittable spaces. CRITICAL: it respects insertion order, so its flattening order matches the field order produced when a dataclass is flattened.
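Insertion order matters because flattened observation values are paired with sub-spaces by position. Plain Python dicts are flattened by JAX in sorted-key order, so a sorted flattening can pair a value with the wrong space. The stdlib sketch below shows the mismatch. `PlayerObs` and the placeholder space strings are hypothetical and stand in for a real observation dataclass and real spaces.

```python
from dataclasses import dataclass, fields, astuple

@dataclass
class PlayerObs:  # hypothetical observation dataclass
    x: float
    y: float
    lives: int

# Placeholder strings stand in for sub-spaces; keys are inserted in the
# same order as the dataclass fields are declared.
space_spec = {"x": "Box[0, 160]", "y": "Box[0, 210]", "lives": "Discrete(4)"}

def aligned(spec: dict, cls) -> bool:
    """True if the dict's key order matches the dataclass field order."""
    return list(spec) == [f.name for f in fields(cls)]

obs = PlayerObs(x=12.0, y=40.0, lives=3)

# astuple() yields values in field order, so insertion-ordered keys pair correctly.
pairs = list(zip(space_spec, astuple(obs)))

# Sorted keys (what a sorted-key flattening would use) pair values wrongly.
sorted_pairs = list(zip(sorted(space_spec), astuple(obs)))
```

Here `pairs` matches each field to its own space. `sorted_pairs` starts with `("lives", 12.0)`, which attaches the x coordinate to the lives space.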
- class jaxatari.spaces.Discrete(num_categories: int)
Bases: Space
Minimal jittable class for discrete spaces.
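A discrete space with `num_categories` categories is usually the integer set {0, ..., num_categories - 1}. The sketch below assumes that convention and shows sampling and membership with NumPy. `DiscreteSketch`, `sample`, and `contains` are hypothetical names. The jittable jaxatari class would presumably use `jax.random` instead of a NumPy generator.

```python
import numpy as np

class DiscreteSketch:
    """Hypothetical stand-in for a discrete space over {0, ..., n - 1}."""

    def __init__(self, num_categories: int):
        if num_categories < 1:
            raise ValueError("num_categories must be positive")
        self.n = num_categories

    def sample(self, rng: np.random.Generator) -> int:
        # Uniform draw from 0..n-1; the upper bound of integers() is exclusive.
        return int(rng.integers(0, self.n))

    def contains(self, x) -> bool:
        # Membership: an integer scalar in the half-open range [0, n).
        x = np.asarray(x)
        return (
            x.shape == ()
            and np.issubdtype(x.dtype, np.integer)
            and 0 <= int(x) < self.n
        )
```

For example, `DiscreteSketch(18)` models an 18-action set. Every sample falls in `0..17`, and `contains(18)` is `False`.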
- class jaxatari.spaces.Tuple(spaces: Sequence[Space])
Bases: Space
A jittable tuple of simpler jittable spaces (Pytree container).
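A tuple space composes its sub-spaces positionally. A value belongs to it when it has exactly one element per sub-space and each element belongs to the matching sub-space. The sketch below reduces each sub-space to a membership predicate. `tuple_contains`, `in_unit`, and `is_action` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Sequence

# Hypothetical: each sub-space is reduced to a membership predicate.
Predicate = Callable[[Any], bool]

def tuple_contains(spaces: Sequence[Predicate], value: Sequence[Any]) -> bool:
    """True iff value has one element per sub-space and each element
    satisfies the predicate at the same position."""
    return len(value) == len(spaces) and all(
        contains(v) for contains, v in zip(spaces, value)
    )

def in_unit(v) -> bool:
    # Stand-in for Box(low=0.0, high=1.0) with scalar shape.
    return 0.0 <= v <= 1.0

def is_action(v) -> bool:
    # Stand-in for Discrete(18).
    return isinstance(v, int) and 0 <= v < 18
```

With `spaces = (in_unit, is_action)`, the value `(0.3, 5)` is accepted, while `(1.3, 5)` and the too-short `(0.3,)` are rejected.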
- jaxatari.spaces.get_object_space(n: int | None = None, screen_size=(210, 160), orientation_range=(0.0, 360.0), xy_low: float = 0.0) → Dict
Generates the standard space for an ObjectObservation.
Parameters:
  n – Number of objects: None (or 1) for scalars, >1 for arrays.
  screen_size – Tuple (height, width) used for the bounds (HWC order, for consistency).
  orientation_range – Tuple (min_orientation, max_orientation) giving the orientation bounds.
  xy_low – Lower bound for x/y (default 0). Use -1 when observations use -1 as an off-screen sentinel.
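The parameters interact as follows. `n` decides between scalar and per-object array shapes, and `screen_size` is read in (height, width) order. `xy_low` lowers the x/y bound so that a -1 sentinel stays in range. The stdlib sketch below computes only the x, y, and orientation bounds. `object_bounds_sketch` is a hypothetical helper, not jaxatari's function. The upper bounds `x <= width` and `y <= height` are assumptions, and the real space covers further ObjectObservation fields that are not reproduced here.

```python
from typing import Optional, Tuple

def object_bounds_sketch(
    n: Optional[int] = None,
    screen_size: Tuple[int, int] = (210, 160),
    orientation_range: Tuple[float, float] = (0.0, 360.0),
    xy_low: float = 0.0,
) -> dict:
    """Hypothetical: (low, high, shape) for x, y, and orientation only.

    Assumes x is bounded by the screen width and y by the screen height.
    """
    height, width = screen_size  # HWC convention: height comes first
    # None or 1 -> scalar fields; n > 1 -> one entry per object.
    shape = () if n is None or n == 1 else (n,)
    return {
        "x": (xy_low, float(width), shape),
        "y": (xy_low, float(height), shape),
        "orientation": (*orientation_range, shape),
    }
```

With the defaults, x is bounded by `(0.0, 160.0)` and y by `(0.0, 210.0)`, both scalar. Calling `object_bounds_sketch(n=4, xy_low=-1)` gives shape `(4,)` fields whose lower x/y bound admits the -1 off-screen sentinel.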