from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_structure
import jax_dataclasses as jdc
import diffrax
import chex
from exciting_environments import ClassicCoreEnvironment
class Pendulum(ClassicCoreEnvironment):
"""
State Variables:
``['theta', 'omega']``
Action Variable:
``['torque']``
Initial State:
Unless chosen otherwise, theta=pi and omega=0
Example:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> import exciting_environments as excenvs
>>> from exciting_environments import GymWrapper
>>>
>>> # Create the environment
>>> pend=excenvs.Pendulum(batch_size=4, action_constraints={"torque": 10}, tau=2e-2)
>>>
>>> # Use GymWrapper for Simulation (optional)
>>> gym_pend=GymWrapper(env=pend)
>>>
>>> # Reset the environment with default initial values
>>> gym_pend.reset()
>>>
>>> # Perform step
>>> obs, reward, terminated, truncated = gym_pend.step(action=jnp.ones(4).reshape(-1,1))
>>>
"""
def __init__(
self,
batch_size: int = 8,
physical_constraints: dict = None,
action_constraints: dict = None,
static_params: dict = None,
control_state: list = None,
solver=diffrax.Euler(),
tau: float = 1e-4,
):
"""
Args:
batch_size (int): Number of parallel environment simulations. Default: 8
physical_constraints (dict): Constraints of the physical state of the environment.
theta (float): Rotation angle. Default: jnp.pi
omega (float): Angular velocity. Default: 10
action_constraints (dict): Constraints of the input/action.
torque (float): Maximum torque that can be applied to the system as an action. Default: 20
static_params (dict): Parameters of environment which do not change during simulation.
                l (float): Length of the pendulum. Default: 2
m (float): Mass of the pendulum tip. Default: 1
g (float): Gravitational acceleration. Default: 9.81
            control_state (list): Names of the physical states for which references are tracked;
                the corresponding reference values enter the observation and the reward. Default: []
solver (diffrax.solver): Solver used to compute state for next step.
tau (float): Duration of one control step in seconds. Default: 1e-4.
        Note: Attributes of physical_constraints, action_constraints and static_params can also be
            passed as jax.Array with the length of the batch_size to set different values per batch.
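
        Example:
            Per-batch static parameters and reference tracking (a minimal sketch;
            the values are illustrative, not defaults):

            >>> import jax.numpy as jnp
            >>> import exciting_environments as excenvs
            >>> env = excenvs.Pendulum(
            ...     batch_size=2,
            ...     static_params={"g": 9.81, "l": jnp.array([1.0, 2.0]), "m": 1},
            ...     control_state=["theta"],
            ... )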
"""
if not physical_constraints:
physical_constraints = {"theta": jnp.pi, "omega": 10}
if not action_constraints:
action_constraints = {"torque": 20}
if not static_params:
static_params = {"g": 9.81, "l": 2, "m": 1}
if not control_state:
control_state = []
self.control_state = control_state
physical_constraints = self.PhysicalState(**physical_constraints)
action_constraints = self.Action(**action_constraints)
static_params = self.StaticParams(**static_params)
super().__init__(
batch_size,
physical_constraints,
action_constraints,
static_params,
tau=tau,
solver=solver,
)
    @jdc.pytree_dataclass
class PhysicalState:
"""Dataclass containing the physical state of the environment."""
theta: jax.Array
omega: jax.Array
    @jdc.pytree_dataclass
class Additions:
"""Dataclass containing additional information for simulation."""
something: jax.Array
    @jdc.pytree_dataclass
class StaticParams:
"""Dataclass containing the static parameters of the environment."""
g: jax.Array
l: jax.Array
m: jax.Array
    @jdc.pytree_dataclass
class Action:
"""Dataclass containing the action, that can be applied to the environment."""
torque: jax.Array
@partial(jax.jit, static_argnums=0)
def _ode_solver_step(self, state, action, static_params):
"""Computes the next state by simulating one step.
Args:
state: The state from which to calculate state for the next step.
action: The action to apply to the environment.
            static_params: Parameters of the environment that do not change over time.
Returns:
next_state: The computed next state after the one step simulation.
"""
physical_state = state.physical_state
args = (action, static_params)
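        # Pendulum dynamics as a first-order ODE system:
        #   d_theta/dt = omega
        #   d_omega/dt = (torque + m * g * l * sin(theta)) / (m * l**2)
        # With this sign convention, theta = 0 is the unstable upright equilibrium
        # and the default initial angle theta = pi is the stable hanging position.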
def vector_field(t, y, args):
theta, omega = y
action, params = args
d_omega = (action[0] + params.l * params.m * params.g * jnp.sin(theta)) / (params.m * (params.l) ** 2)
d_theta = omega
d_y = d_theta, d_omega
return d_y
term = diffrax.ODETerm(vector_field)
t0 = 0
t1 = self.tau
        y0 = (physical_state.theta, physical_state.omega)
env_state = self._solver.init(term, t0, t1, y0, args)
y, _, _, env_state, _ = self._solver.step(term, t0, t1, y0, args, env_state, made_jump=False)
theta_k1 = y[0]
omega_k1 = y[1]
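        # keep theta between -pi and pi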
theta_k1 = ((theta_k1 + jnp.pi) % (2 * jnp.pi)) - jnp.pi
with jdc.copy_and_mutate(state, validate=False) as new_state:
new_state.physical_state = self.PhysicalState(theta=theta_k1, omega=omega_k1)
return new_state
@partial(jax.jit, static_argnums=[0, 4, 5])
def _ode_solver_simulate_ahead(self, init_state, actions, static_params, obs_stepsize, action_stepsize):
"""Computes states by simulating a trajectory with given actions."""
init_physical_state = init_state.physical_state
args = (actions, static_params)
def force(t, args):
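            # zero-order hold: look up the action that is active at time t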
actions = args
return actions[jnp.array(t / action_stepsize, int), 0]
def vector_field(t, y, args):
theta, omega = y
actions, params = args
d_omega = (force(t, actions) + params.l * params.m * params.g * jnp.sin(theta)) / (
params.m * (params.l) ** 2
)
d_theta = omega
d_y = d_theta, d_omega
return d_y
term = diffrax.ODETerm(vector_field)
t0 = 0
t1 = action_stepsize * actions.shape[0]
init_physical_state_array, _ = tree_flatten(init_physical_state)
y0 = tuple(init_physical_state_array)
        # save the solution on a uniform time grid with spacing obs_stepsize
        saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1 + int(t1 / obs_stepsize)))
sol = diffrax.diffeqsolve(term, self._solver, t0, t1, dt0=obs_stepsize, y0=y0, args=args, saveat=saveat)
theta_t = sol.ys[0]
omega_t = sol.ys[1]
# keep theta between -pi and pi
theta_t = ((theta_t + jnp.pi) % (2 * jnp.pi)) - jnp.pi
physical_states = self.PhysicalState(theta=theta_t, omega=omega_t)
ref = self.PhysicalState(theta=jnp.nan, omega=jnp.nan)
additions = None
PRNGKey = None
return self.State(physical_state=physical_states, PRNGKey=PRNGKey, additions=additions, reference=ref)
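
    # Example call (a sketch, assuming an unbatched `state` and that `env_properties`
    # exposes the static_params passed to __init__; stepsizes are illustrative):
    #
    #     actions = jnp.zeros((100, 1))  # 100 torque values, each held for action_stepsize
    #     traj = env._ode_solver_simulate_ahead(
    #         state, actions, env.env_properties.static_params,
    #         obs_stepsize=1e-4, action_stepsize=1e-2,
    #     )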
    @partial(jax.jit, static_argnums=0)
def init_state(self, env_properties, rng: chex.PRNGKey = None, vmap_helper=None):
"""Returns default initial state for all batches."""
if rng is None:
phys = self.PhysicalState(
theta=jnp.pi,
omega=0,
)
subkey = None
else:
            # split first so the sampled state and the carried key do not reuse rng
            key, subkey = jax.random.split(rng)
            state_norm = jax.random.uniform(key, minval=-1, maxval=1, shape=(2,))
            phys = self.PhysicalState(
                theta=state_norm[0] * env_properties.physical_constraints.theta,
                omega=state_norm[1] * env_properties.physical_constraints.omega,
            )
        additions = None  # e.g. self.Additions(something=jnp.zeros(self.batch_size))
ref = self.PhysicalState(theta=jnp.nan, omega=jnp.nan)
return self.State(physical_state=phys, PRNGKey=subkey, additions=additions, reference=ref)
@partial(jax.jit, static_argnums=0)
def vmap_init_state(self, rng: chex.PRNGKey = None):
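        """Returns initial states for all batches by vmapping init_state."""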
return jax.vmap(self.init_state, in_axes=(self.in_axes_env_properties, 0, 0))(
self.env_properties, rng, jnp.ones(self.batch_size)
)
    @partial(jax.jit, static_argnums=0)
def generate_reward(self, state, action, env_properties):
"""Returns reward for one batch."""
reward = 0
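        # negative squared tracking error for each controlled state, normalized by its constraint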
for name in self.control_state:
reward += -(
(
(getattr(state.physical_state, name) - getattr(state.reference, name))
/ (getattr(env_properties.physical_constraints, name)).astype(float)
)
** 2
)
return jnp.array([reward])
    @partial(jax.jit, static_argnums=0)
def generate_observation(self, state, env_properties):
"""Returns observation for one batch."""
physical_constraints = env_properties.physical_constraints
obs = jnp.hstack(
(
state.physical_state.theta / physical_constraints.theta,
state.physical_state.omega / physical_constraints.omega,
)
)
for name in self.control_state:
obs = jnp.hstack(
(
obs,
(getattr(state.reference, name) / (getattr(physical_constraints, name)).astype(float)),
)
)
return obs
    @partial(jax.jit, static_argnums=0)
def generate_truncated(self, state, env_properties):
"""Returns truncated information for one batch."""
obs = self.generate_observation(state, env_properties)
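        # truncate elementwise once a normalized observation leaves [-1, 1]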
return jnp.abs(obs) > 1
    @partial(jax.jit, static_argnums=0)
def generate_terminated(self, state, reward, env_properties):
"""Returns terminated information for one batch."""
return reward == 0
@property
def obs_description(self):
return np.hstack([np.array(["theta", "omega"]), np.array([name + "_ref" for name in self.control_state])])
@property
def action_description(self):
return np.array(["torque"])
    def reset(self, rng: chex.PRNGKey = None, initial_state: jdc.pytree_dataclass = None):
"""Resets environment to default or passed initial state."""
if initial_state is not None:
            assert tree_structure(self.vmap_init_state()) == tree_structure(
                initial_state
            ), "initial_state should have the same dataclass structure as self.vmap_init_state()"
state = initial_state
else:
state = self.vmap_init_state(rng)
obs = jax.vmap(
self.generate_observation,
in_axes=(0, self.in_axes_env_properties),
)(state, self.env_properties)
return obs, state