Source code for exciting_environments.mass_spring_damper.mass_spring_damper_env

from functools import partial
from typing import Callable

import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_structure
import jax_dataclasses as jdc
import chex
import diffrax

from exciting_environments import ClassicCoreEnvironment


class MassSpringDamper(ClassicCoreEnvironment):
    """
    State Variables:
        ``['deflection', 'velocity']``

    Action Variable:
        ``['force']``

    Initial State:
        Unless chosen otherwise, deflection and velocity are set to zero.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import exciting_environments as excenvs
        >>> from exciting_environments import GymWrapper
        >>>
        >>> # Create the environment
        >>> msd = excenvs.MassSpringDamper(batch_size=4, action_constraints={"force": 10})
        >>>
        >>> # Use GymWrapper for simulation (optional)
        >>> gym_msd = GymWrapper(env=msd)
        >>>
        >>> # Reset the environment with default initial values
        >>> gym_msd.reset()
        >>>
        >>> # Perform step
        >>> obs, reward, terminated, truncated = gym_msd.step(action=jnp.ones(4).reshape(-1, 1))
    """

    def __init__(
        self,
        batch_size: int = 8,
        physical_constraints: dict = None,
        action_constraints: dict = None,
        static_params: dict = None,
        control_state: list = None,
        solver=diffrax.Euler(),
        tau: float = 1e-4,
    ):
        """
        Args:
            batch_size (int): Number of training examples utilized in one iteration. Default: 8
            physical_constraints (dict): Constraints of the physical state of the environment.
                deflection (float): Maximum deflection of the mass. Default: 10
                velocity (float): Maximum velocity of the mass. Default: 10
            action_constraints (dict): Constraints of the action.
                force (float): Maximum force that can be applied to the system as action. Default: 20
            static_params (dict): Parameters of the environment which do not change during simulation.
                d (float): Damping constant. Default: 1
                k (float): Spring constant. Default: 100
                m (float): Mass of the oscillating object. Default: 1
            control_state (list): Names of the physical states (e.g., ["deflection"]) to be
                controlled, i.e., for which references are appended to the observation and
                penalized in the reward. Default: []
            solver (diffrax.solver): Solver used to compute the state for the next step.
            tau (float): Duration of one control step in seconds. Default: 1e-4

        Note:
            Attributes of physical_constraints, action_constraints and static_params can
            also be passed as jnp.Array with the length of the batch_size to set different
            values per batch.
        """
        if not physical_constraints:
            physical_constraints = {"deflection": 10, "velocity": 10}

        if not action_constraints:
            action_constraints = {"force": 20}

        if not static_params:
            static_params = {"k": 100, "d": 1, "m": 1}

        if not control_state:
            control_state = []

        self.control_state = control_state

        physical_constraints = self.PhysicalState(**physical_constraints)
        action_constraints = self.Action(**action_constraints)
        static_params = self.StaticParams(**static_params)

        super().__init__(
            batch_size,
            physical_constraints,
            action_constraints,
            static_params,
            tau=tau,
            solver=solver,
        )
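    # Illustrative sketch (not part of the original class): as the Note above says,
    # constraints and static parameters may also be passed as jnp.Arrays of length
    # batch_size to simulate heterogeneous environments, e.g. four oscillators with
    # different spring constants:
    # >>> import jax.numpy as jnp
    # >>> msd = MassSpringDamper(
    # ...     batch_size=4,
    # ...     static_params={"k": jnp.array([50.0, 100.0, 150.0, 200.0]), "d": 1, "m": 1},
    # ... )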
    @jdc.pytree_dataclass
    class PhysicalState:
        """Dataclass containing the physical state of the environment."""

        deflection: jax.Array
        velocity: jax.Array

    @jdc.pytree_dataclass
    class Additions:
        """Dataclass containing additional information for simulation."""

        something: jax.Array

    @jdc.pytree_dataclass
    class StaticParams:
        """Dataclass containing the static parameters of the environment."""

        d: jax.Array
        k: jax.Array
        m: jax.Array

    @jdc.pytree_dataclass
    class Action:
        """Dataclass containing the action that can be applied to the environment."""

        force: jax.Array
    @partial(jax.jit, static_argnums=0)
    def _ode_solver_step(self, state, action, static_params):
        """Computes state by simulating one step.

        Args:
            state: The state from which to calculate state for the next step.
            action: The action to apply to the environment.
            static_params: Parameters of the environment that do not change over time.

        Returns:
            state: The computed state after the one step simulation.
        """
        physical_state = state.physical_state
        args = (action, static_params)

        def vector_field(t, y, args):
            # Mass-spring-damper dynamics: m * x'' + d * x' + k * x = F
            deflection, velocity = y
            action, params = args
            d_velocity = (action[0] - params.d * velocity - params.k * deflection) / params.m
            d_deflection = velocity
            d_y = d_deflection, d_velocity
            return d_y

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = tuple([physical_state.deflection, physical_state.velocity])
        env_state = self._solver.init(term, t0, t1, y0, args)
        y, _, _, env_state, _ = self._solver.step(term, t0, t1, y0, args, env_state, made_jump=False)

        deflection_k1 = y[0]
        velocity_k1 = y[1]

        with jdc.copy_and_mutate(state, validate=False) as new_state:
            new_state.physical_state = self.PhysicalState(deflection=deflection_k1, velocity=velocity_k1)
        return new_state

    @partial(jax.jit, static_argnums=[0, 4, 5])
    def _ode_solver_simulate_ahead(self, init_state, actions, static_params, obs_stepsize, action_stepsize):
        """Computes states by simulating a trajectory with given actions."""
        init_physical_state = init_state.physical_state
        args = (actions, static_params)

        def force(t, args):
            # Zero-order hold: look up the action that is active at time t.
            actions = args
            return actions[jnp.array(t / action_stepsize, int), 0]

        def vector_field(t, y, args):
            deflection, velocity = y
            actions, params = args
            d_velocity = (force(t, actions) - params.d * velocity - params.k * deflection) / params.m
            d_deflection = velocity
            d_y = d_deflection, d_velocity
            return d_y

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = action_stepsize * actions.shape[0]
        init_physical_state_array, _ = tree_flatten(init_physical_state)
        y0 = tuple(init_physical_state_array)
        saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1 + int(t1 / obs_stepsize)))

        sol = diffrax.diffeqsolve(term, self._solver, t0, t1, dt0=obs_stepsize, y0=y0, args=args, saveat=saveat)

        deflection_t = sol.ys[0]
        velocity_t = sol.ys[1]
        physical_states = self.PhysicalState(deflection=deflection_t, velocity=velocity_t)
        ref = self.PhysicalState(
            deflection=jnp.full(deflection_t.shape, jnp.nan), velocity=jnp.full(velocity_t.shape, jnp.nan)
        )
        additions = None
        PRNGKey = None
        return self.State(physical_state=physical_states, PRNGKey=PRNGKey, additions=additions, reference=ref)
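    # Cross-check (hedged sketch): with the default diffrax.Euler() solver, one
    # _ode_solver_step call reduces to an explicit Euler update of the ODE
    # m * x'' + d * x' + k * x = F over one step of length tau:
    # >>> # deflection_k1 = deflection + tau * velocity
    # >>> # velocity_k1 = velocity + tau * (force - d * velocity - k * deflection) / m
    # Other diffrax solvers (e.g. diffrax.Tsit5()) can be swapped in via the
    # `solver` constructor argument without changing this method.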
    @partial(jax.jit, static_argnums=0)
    def init_state(self, env_properties, rng: chex.PRNGKey = None, vmap_helper=None):
        """Returns default initial state for all batches."""
        if rng is None:
            phys = self.PhysicalState(
                deflection=0,
                velocity=0,
            )
            subkey = None
        else:
            state_norm = jax.random.uniform(rng, minval=-1, maxval=1, shape=(2,))
            phys = self.PhysicalState(
                deflection=state_norm[0] * env_properties.physical_constraints.deflection,
                velocity=state_norm[1] * env_properties.physical_constraints.velocity,
            )
            key, subkey = jax.random.split(rng)
        additions = None  # self.Optional(something=jnp.zeros(self.batch_size))
        ref = self.PhysicalState(deflection=jnp.nan, velocity=jnp.nan)
        return self.State(physical_state=phys, PRNGKey=subkey, additions=additions, reference=ref)
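    # Note (derived from the rng branch above): when an rng key is given, the
    # initial state is drawn uniformly from the normalized state space and scaled
    # by the constraints, i.e.
    # >>> # deflection_0 ~ U(-1, 1) * physical_constraints.deflection
    # >>> # velocity_0   ~ U(-1, 1) * physical_constraints.velocity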
    @partial(jax.jit, static_argnums=0)
    def vmap_init_state(self, rng: chex.PRNGKey = None):
        return jax.vmap(self.init_state, in_axes=(self.in_axes_env_properties, 0, 0))(
            self.env_properties, rng, jnp.ones(self.batch_size)
        )
    @partial(jax.jit, static_argnums=0)
    def generate_reward(self, state, action, env_properties):
        """Returns reward for one batch."""
        reward = 0
        for name in self.control_state:
            # Negative squared tracking error, normalized by the state constraint.
            reward += -(
                (
                    (getattr(state.physical_state, name) - getattr(state.reference, name))
                    / (getattr(env_properties.physical_constraints, name)).astype(float)
                )
                ** 2
            )
        return jnp.array([reward])
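    # Worked example (illustrative numbers): with control_state=["deflection"],
    # a deflection of 2.0, a reference of 1.0 and a deflection constraint of 10.0,
    # the reward is -((2.0 - 1.0) / 10.0) ** 2 = -0.01. The reward is 0 exactly
    # when every controlled state matches its reference.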
    @partial(jax.jit, static_argnums=0)
    def generate_observation(self, state, env_properties):
        """Returns observation for one batch."""
        physical_constraints = env_properties.physical_constraints
        obs = jnp.hstack(
            (
                state.physical_state.deflection / physical_constraints.deflection,
                state.physical_state.velocity / physical_constraints.velocity,
            )
        )
        for name in self.control_state:
            # Append the normalized reference of each controlled state.
            obs = jnp.hstack(
                (
                    obs,
                    (getattr(state.reference, name) / (getattr(physical_constraints, name)).astype(float)),
                )
            )
        return obs
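    # Resulting observation layout (derived from the code above): the normalized
    # physical states come first, followed by one normalized reference per entry
    # of control_state, e.g. with control_state=["deflection"]:
    # >>> # obs = [deflection / max_deflection, velocity / max_velocity,
    # >>> #        deflection_ref / max_deflection]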
    @partial(jax.jit, static_argnums=0)
    def generate_truncated(self, state, env_properties):
        """Returns truncated information for one batch."""
        obs = self.generate_observation(state, env_properties)
        return jnp.abs(obs) > 1
    @partial(jax.jit, static_argnums=0)
    def generate_terminated(self, state, reward, env_properties):
        """Returns terminated information for one batch."""
        return reward == 0
    @property
    def obs_description(self):
        return np.hstack(
            [np.array(["deflection", "velocity"]), np.array([name + "_ref" for name in self.control_state])]
        )

    @property
    def action_description(self):
        return np.array(["force"])
    def reset(self, rng: chex.PRNGKey = None, initial_state: jdc.pytree_dataclass = None):
        """Resets environment to default or passed initial state."""
        if initial_state is not None:
            assert tree_structure(self.vmap_init_state()) == tree_structure(
                initial_state
            ), "initial_state should have the same dataclass structure as self.vmap_init_state()"
            state = initial_state
        else:
            state = self.vmap_init_state(rng)

        obs = jax.vmap(
            self.generate_observation,
            in_axes=(0, self.in_axes_env_properties),
        )(state, self.env_properties)

        return obs, state
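# Usage sketch (hedged; shapes follow from the vmapped calls above):
# >>> env = MassSpringDamper(batch_size=4)
# >>> keys = jax.random.split(jax.random.PRNGKey(0), 4)
# >>> obs, state = env.reset(rng=keys)
# >>> obs.shape  # (batch_size, 2) while control_state is empty
# (4, 2)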