Spaces
- class jaxatari.spaces.Box(low: float | numpy.ndarray | jax.Array, high: float | numpy.ndarray | jax.Array, shape: typing.Tuple[int, ...] | None = None, dtype: numpy.dtype = jax.numpy.float32)
Bases: Space
A jittable n-dimensional box space.
This space represents the Cartesian product of n closed intervals. Each interval has its own lower and upper bound.
It can be initialized in two ways:
1. With scalar bounds and an explicit shape, creating a box with uniform bounds. Example: Box(low=0.0, high=1.0, shape=(3, 4))
2. With array-like bounds, where the shape is inferred from the bounds arrays. Example: Box(low=jnp.array([0., -1.]), high=jnp.array([1., 1.]))
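The two initialization modes can be illustrated with a small stand-alone sketch. This is not jaxatari's implementation: `resolve_box_bounds` and `contains` are hypothetical helpers written for illustration, and NumPy is used in place of `jax.numpy` so the snippet runs without JAX installed. It assumes that scalar bounds are broadcast to the requested shape and that membership means every coordinate lies within its closed interval.

```python
import numpy as np


def resolve_box_bounds(low, high, shape=None, dtype=np.float32):
    """Hypothetical helper (not part of jaxatari): return full-shape bound arrays."""
    low = np.asarray(low, dtype=dtype)
    high = np.asarray(high, dtype=dtype)
    if shape is None:
        # Mode 2: no explicit shape, so infer it from the bounds arrays.
        shape = np.broadcast(low, high).shape
    # Mode 1: scalar bounds are expanded to the requested shape.
    # In mode 2 this only aligns low and high to their common shape.
    return np.broadcast_to(low, shape), np.broadcast_to(high, shape)


def contains(x, low, high):
    """Assumed membership rule: same shape, every entry inside its closed interval."""
    x = np.asarray(x)
    return x.shape == low.shape and bool(np.all((x >= low) & (x <= high)))


# Mode 1: uniform bounds with an explicit shape.
low_a, high_a = resolve_box_bounds(0.0, 1.0, shape=(3, 4))

# Mode 2: per-dimension bounds, shape inferred as (2,).
low_b, high_b = resolve_box_bounds(np.array([0.0, -1.0]), np.array([1.0, 1.0]))
```

In the second mode each dimension keeps its own interval, so a point such as `[0.5, -0.5]` lies inside the box while `[0.5, -1.5]` does not.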
- class jaxatari.spaces.Dict(spaces: dict)
Bases: Space
A jittable dictionary of simpler jittable spaces (Pytree container).
- class jaxatari.spaces.Discrete(num_categories: int)
Bases: Space
Minimal jittable class for discrete spaces.
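How a dictionary space composes simpler spaces can be sketched in plain Python. The classes below, `DiscreteSketch` and `DictSketch`, are hypothetical stand-ins and not the jaxatari classes; they are not jittable. The sketch assumes the common convention that a discrete space with `num_categories` categories holds the integers 0 through `num_categories - 1`, and that a dictionary value belongs to the space when its keys match and each entry belongs to the corresponding sub-space. Neither rule is stated in the reference above.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass(frozen=True)
class DiscreteSketch:
    """Hypothetical stand-in for a discrete space (assumed range: 0 .. n - 1)."""
    num_categories: int

    def contains(self, x) -> bool:
        # Reject bools explicitly, since bool is a subclass of int in Python.
        is_int = isinstance(x, int) and not isinstance(x, bool)
        return is_int and 0 <= x < self.num_categories


@dataclass(frozen=True)
class DictSketch:
    """Hypothetical stand-in for a dictionary of simpler spaces."""
    spaces: Mapping[str, Any]

    def contains(self, x) -> bool:
        # Keys must match exactly, and each value must lie in its sub-space.
        return (
            isinstance(x, Mapping)
            and x.keys() == self.spaces.keys()
            and all(self.spaces[k].contains(x[k]) for k in self.spaces)
        )


# Illustrative composite space; the key names are made up for this example.
obs_space = DictSketch({"lives": DiscreteSketch(4), "level": DiscreteSketch(10)})
inside = obs_space.contains({"lives": 3, "level": 0})
outside = obs_space.contains({"lives": 4, "level": 0})
```

Here `inside` is true and `outside` is false, because 4 falls outside the assumed range 0 to 3 for the `"lives"` entry.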