Spaces

class jaxatari.spaces.Box(low: float | numpy.ndarray | jax.Array, high: float | numpy.ndarray | jax.Array, shape: Tuple[int, ...] | None = None, dtype: numpy.dtype = jax.numpy.float32)

Bases: Space

A jittable n-dimensional box space.

This space represents the Cartesian product of n closed intervals. Each interval has its own lower and upper bound.

It can be initialized in two ways:

  1. With scalar bounds and an explicit shape, creating a box with uniform bounds. Example: Box(low=0.0, high=1.0, shape=(3, 4))
  2. With array-like bounds, where the shape is inferred from the bounds arrays. Example: Box(low=jnp.array([0., -1.]), high=jnp.array([1., 1.]))
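The two construction modes can be illustrated with a hypothetical helper, `box_bounds`, that resolves the bounds to two arrays of the same shape. It assumes that scalar bounds are broadcast to `shape` and that, when `shape` is omitted, the shape is the broadcast shape of `low` and `high`. This is a sketch of the idea, not the jaxatari implementation.

```python
import jax.numpy as jnp


def box_bounds(low, high, shape=None, dtype=jnp.float32):
    """Hypothetical helper: resolve Box-style bounds to two arrays of one shape."""
    low = jnp.asarray(low, dtype=dtype)
    high = jnp.asarray(high, dtype=dtype)
    if shape is None:
        # Mode 2: infer the shape from the (broadcast) bounds arrays.
        shape = jnp.broadcast_shapes(low.shape, high.shape)
    # Mode 1: broadcast scalar bounds to the requested shape
    # (a no-op when the shape was inferred in mode 2).
    low = jnp.broadcast_to(low, shape)
    high = jnp.broadcast_to(high, shape)
    return low, high


low, high = box_bounds(0.0, 1.0, shape=(3, 4))
print(low.shape, high.shape)  # (3, 4) (3, 4)

low, high = box_bounds(jnp.array([0., -1.]), jnp.array([1., 1.]))
print(low.shape)  # (2,)
```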

contains(x: Array) -> Array

Check if a point x is contained within the box.
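Because the box is a product of closed intervals, a membership test reduces to an elementwise inclusive bound check plus a shape check. The sketch below uses a hypothetical function `box_contains`; it assumes a shape mismatch simply yields `False`, which may differ from the library. Shapes are static under `jax.jit`, so the Python-level shape check is safe to trace.

```python
import jax
import jax.numpy as jnp


def box_contains(low, high, x):
    """Sketch of a jittable membership test for a closed box (assumed semantics)."""
    x = jnp.asarray(x)
    if x.shape != low.shape:
        # Shapes are static during tracing, so this branch is resolved at trace time.
        return jnp.array(False)
    # Closed intervals: both bounds are inclusive.
    return jnp.all((x >= low) & (x <= high))


low = jnp.array([0.0, -1.0])
high = jnp.array([1.0, 1.0])
contains = jax.jit(box_contains)
print(contains(low, high, jnp.array([0.5, -1.0])))  # True
print(contains(low, high, jnp.array([1.5, 0.0])))   # False
```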

range() -> tuple[Array, Array]

Returns the lower and upper bounds of the space.

sample(key: Array) -> Array

Generates a random sample from the space.

The sample is uniformly distributed over the box.
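For a floating-point box, uniform sampling maps directly onto `jax.random.uniform`, whose `minval`/`maxval` arguments broadcast against the requested shape. The sketch below (hypothetical `box_sample`) assumes a float dtype; note that `jax.random.uniform` draws from the half-open interval [minval, maxval), so the upper bound itself is never returned. Because the function is pure in `key`, it can be `vmap`ped over a batch of keys.

```python
import jax
import jax.numpy as jnp


def box_sample(key, low, high, dtype=jnp.float32):
    """Sketch: draw one point uniformly from a float box with bounds low/high."""
    # Elementwise uniform draw in [low, high); bounds broadcast to low.shape.
    return jax.random.uniform(key, shape=low.shape, dtype=dtype, minval=low, maxval=high)


low = jnp.array([0.0, -1.0])
high = jnp.array([1.0, 1.0])
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
# vmap over the keys to draw 1000 independent samples in one call.
samples = jax.vmap(lambda k: box_sample(k, low, high))(keys)
print(samples.shape)  # (1000, 2)
```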

class jaxatari.spaces.Dict(spaces: dict)

Bases: Space

A jittable dictionary of simpler jittable spaces. Important: subspaces are kept in insertion order so that their flattening order matches the field order of the corresponding dataclass.

contains(x: dict) -> Array
sample(key: Array) -> OrderedDict
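Insertion order matters because JAX flattens a plain `dict` in sorted-key order, while an `OrderedDict` is flattened in insertion order; presumably this is why `sample` returns an `OrderedDict`. The sketch below demonstrates both flattening orders and then samples each subspace with its own split key. `dict_sample` and the `samplers` mapping are hypothetical stand-ins, not the library's API.

```python
from collections import OrderedDict

import jax

# JAX flattens a plain dict in sorted-key order...
plain = {"y": 1, "x": 2}
print(jax.tree_util.tree_leaves(plain))    # [2, 1]

# ...whereas an OrderedDict keeps insertion order, which can match the
# declaration order of a dataclass with fields (y, x).
ordered = OrderedDict([("y", 1), ("x", 2)])
print(jax.tree_util.tree_leaves(ordered))  # [1, 2]


def dict_sample(key, samplers):
    """Hypothetical sketch: sample each subspace with its own key, preserving order.

    `samplers` maps a subspace name to a function key -> sample.
    """
    keys = jax.random.split(key, len(samplers))
    return OrderedDict(
        (name, sample_fn(k)) for (name, sample_fn), k in zip(samplers.items(), keys)
    )


samplers = OrderedDict([
    ("y", lambda k: jax.random.uniform(k, (2,))),
    ("x", lambda k: jax.random.randint(k, (), 0, 4)),
])
sample = dict_sample(jax.random.PRNGKey(0), samplers)
print(list(sample))  # ['y', 'x']
```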
class jaxatari.spaces.Discrete(num_categories: int)

Bases: Space

Minimal jittable class for discrete spaces.

contains(x: Array) -> Array

Check whether a specific object is within the space.

range() -> tuple[float, float]
sample(key: Array) -> Array

Sample a random action uniformly from the set of categorical choices.
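A minimal sketch of the two discrete operations, assuming the categories are the integers 0 to `num_categories - 1`. `discrete_sample` and `discrete_contains` are hypothetical names; `jax.random.randint` draws from the half-open range [minval, maxval), which gives exactly that set. The integer-dtype check is a static Python bool, so it does not break tracing.

```python
import jax
import jax.numpy as jnp


def discrete_sample(key, num_categories):
    # Uniform integer in [0, num_categories).
    return jax.random.randint(key, shape=(), minval=0, maxval=num_categories)


def discrete_contains(x, num_categories):
    x = jnp.asarray(x)
    # dtype is known at trace time, so this is a plain Python bool.
    is_int = jnp.issubdtype(x.dtype, jnp.integer)
    in_range = jnp.logical_and(x >= 0, x < num_categories)
    return jnp.logical_and(is_int, in_range)


keys = jax.random.split(jax.random.PRNGKey(42), 500)
actions = jax.vmap(lambda k: discrete_sample(k, 18))(keys)
print(actions.shape)  # (500,)
```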

class jaxatari.spaces.Space

Bases: object

Minimal jittable class for abstract spaces.

contains(x: Array) -> Any
range()
sample(key: Array) -> Array
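The sketch below shows one way such an abstract interface can be made "jittable": subclasses implement pure functions of their array arguments, and the space object itself stays ordinary Python state captured by the bound method, so `jax.jit` traces only `key` or `x`. `SpaceSketch` and `UnitInterval` are hypothetical classes for illustration, not the jaxatari source.

```python
import jax
import jax.numpy as jnp


class SpaceSketch:
    """Hypothetical abstract interface mirroring Space's three methods."""

    def sample(self, key: jax.Array) -> jax.Array:
        raise NotImplementedError

    def contains(self, x: jax.Array):
        raise NotImplementedError

    def range(self):
        raise NotImplementedError


class UnitInterval(SpaceSketch):
    """Toy subclass: the closed interval [0, 1]."""

    def sample(self, key):
        return jax.random.uniform(key)  # scalar float in [0, 1)

    def contains(self, x):
        return jnp.logical_and(x >= 0.0, x <= 1.0)

    def range(self):
        return 0.0, 1.0


space = UnitInterval()
# jit the bound methods; `space` is closed over, only key / x are traced.
sample = jax.jit(space.sample)
contains = jax.jit(space.contains)
x = sample(jax.random.PRNGKey(0))
print(bool(contains(x)))  # True
```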
class jaxatari.spaces.Tuple(spaces: Sequence[Space])

Bases: Space

A jittable tuple of simpler jittable spaces (Pytree container).

contains(x: tuple) -> Array

Check whether the given Pytree is contained in the space.

sample(key: Array) -> tuple

Sample a random tuple from all subspaces.
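A tuple space can be sketched by splitting the key once per subspace for sampling, and by reducing the per-subspace checks with a logical AND for membership. Here each subspace is represented by a hypothetical `(sample_fn, contains_fn)` pair, and `tuple_sample` / `tuple_contains` are illustrative names. The sketch assumes a length mismatch means "not contained".

```python
import jax
import jax.numpy as jnp

# Hypothetical subspaces as (sample_fn, contains_fn) pairs.
subspaces = (
    # A 2-D box with bounds [0, 1] per component.
    (lambda k: jax.random.uniform(k, (2,)),
     lambda x: jnp.all((x >= 0.0) & (x <= 1.0))),
    # A discrete space with 4 categories.
    (lambda k: jax.random.randint(k, (), 0, 4),
     lambda x: (x >= 0) & (x < 4)),
)


def tuple_sample(key, spaces):
    # One independent key per subspace.
    keys = jax.random.split(key, len(spaces))
    return tuple(sample_fn(k) for (sample_fn, _), k in zip(spaces, keys))


def tuple_contains(x, spaces):
    if len(x) != len(spaces):
        return jnp.array(False)
    checks = [contains_fn(xi) for (_, contains_fn), xi in zip(spaces, x)]
    # Contained only if every element lies in its subspace.
    return jnp.all(jnp.stack(checks))


s = tuple_sample(jax.random.PRNGKey(1), subspaces)
print(bool(tuple_contains(s, subspaces)))  # True
```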

jaxatari.spaces.get_object_space(n: int | None = None, screen_size=(210, 160), orientation_range=(0.0, 360.0), xy_low: float = 0.0) -> Dict

Generates the standard space for an ObjectObservation.

Parameters:
  n: Number of objects. None (or 1) for scalars, >1 for arrays.
  screen_size: Tuple (height, width) used for the bounds (HWC order, for consistency).
  orientation_range: Tuple (min_orientation, max_orientation) giving the orientation bounds.
  xy_low: Lower bound for x/y (default 0). Use -1 when observations use -1 as an off-screen sentinel.
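The parameter rules above can be sketched as follows: `n` decides between a scalar field and a length-`n` array, `screen_size` is unpacked as (height, width), and `xy_low` sets the lower x/y bound. The helpers `object_field_shape` and `xy_bounds` are hypothetical, and the choice of width (for x) and height (for y) as the upper bounds is an assumption; the real `get_object_space` may use different bounds and fields. `orientation_range` would be handled analogously and is omitted.

```python
import jax.numpy as jnp


def object_field_shape(n=None):
    """Rule from the docs: None or 1 -> scalar field, n > 1 -> array of length n."""
    return () if n is None or n == 1 else (n,)


def xy_bounds(n=None, screen_size=(210, 160), xy_low=0.0):
    """Hypothetical helper returning (low, high) arrays for the x and y fields.

    ASSUMPTION: upper bounds are the screen width (x) and height (y).
    """
    height, width = screen_size  # HWC order: height comes first
    shape = object_field_shape(n)
    return {
        "x": (jnp.full(shape, xy_low), jnp.full(shape, float(width))),
        "y": (jnp.full(shape, xy_low), jnp.full(shape, float(height))),
    }


# Five objects, with -1 reserved as an off-screen sentinel.
b = xy_bounds(n=5, xy_low=-1.0)
print(b["x"][0].shape)  # (5,)
```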

jaxatari.spaces.stack_space(space: Space, stack_size: int) -> Space

Recursively wraps a space or a Pytree of spaces to add a stacking dimension to each leaf space. Handles Box and Discrete spaces as leaves.
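The recursion over a pytree of spaces can be expressed with `jax.tree_util.tree_map`: instances of a class that is not registered as a pytree node are treated as leaves, so the mapped function sees each leaf space whole. The sketch below handles only Box-like leaves via a hypothetical `BoxSpec` dataclass and assumes the stacking dimension is prepended as the leading axis; how the library stacks `Discrete` leaves is not shown here.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class BoxSpec:
    """Hypothetical Box stand-in; unregistered, so tree_map treats it as a leaf."""
    low: jax.Array
    high: jax.Array


def stack_box(box: BoxSpec, stack_size: int) -> BoxSpec:
    # ASSUMPTION: the stack dimension becomes the new leading axis.
    def tile(a):
        return jnp.broadcast_to(a, (stack_size,) + a.shape)

    return BoxSpec(tile(box.low), tile(box.high))


def stack_space_sketch(space_tree, stack_size: int):
    """Apply stack_box to every BoxSpec leaf of a dict/tuple/list pytree."""
    return jax.tree_util.tree_map(lambda s: stack_box(s, stack_size), space_tree)


spaces = {
    "position": BoxSpec(jnp.zeros(2), jnp.array([160.0, 210.0])),
    "counter": (BoxSpec(jnp.array(0.0), jnp.array(5.0)),),
}
stacked = stack_space_sketch(spaces, stack_size=4)
print(stacked["position"].low.shape)     # (4, 2)
print(stacked["counter"][0].high.shape)  # (4,)
```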