kos_sim.simulator

Wrapper around MuJoCo simulation.

  1"""Wrapper around MuJoCo simulation."""
  2
  3import math
  4import random
  5import time
  6from dataclasses import dataclass
  7from pathlib import Path
  8from typing import Literal, NotRequired, TypedDict, TypeVar
  9
 10import mujoco
 11import mujoco_viewer
 12import numpy as np
 13from kscale.web.gen.api import RobotURDFMetadataOutput
 14from mujoco_scenes.mjcf import load_mjmodel
 15
 16from kos_sim import logger
 17
 18T = TypeVar("T")
 19
 20
 21def _nn(value: T | None) -> T:
 22    if value is None:
 23        raise ValueError("Value is not set")
 24    return value
 25
 26
class ConfigureActuatorRequest(TypedDict):
    """Optional per-actuator settings accepted by ``MujocoSimulator.configure_actuator``.

    Every key may be omitted; ``configure_actuator`` only applies the keys that
    are present in the request.
    """

    # Declared here, but not read by ``configure_actuator`` as written.
    torque_enabled: NotRequired[bool]
    # Declared here, but not read by ``configure_actuator`` as written.
    zero_position: NotRequired[float]
    # Proportional gain; replaces the joint's stored kp used by the PD law in ``step``.
    kp: NotRequired[float]
    # Derivative gain; replaces the joint's stored kd used by the PD law in ``step``.
    kd: NotRequired[float]
    # Symmetric bound; ``step`` clips the commanded torque to [-max_torque, max_torque].
    max_torque: NotRequired[float]
 33
 34
@dataclass
class ActuatorState:
    """Joint reading returned by ``MujocoSimulator.get_actuator_state``."""

    # Joint qpos plus the fixed per-joint offset and fresh uniform noise.
    position: float
    # Joint qvel plus fresh uniform noise.
    velocity: float
 39
 40
class ActuatorCommand(TypedDict):
    """Target for one actuator, consumed by the PD law in ``MujocoSimulator.step``.

    All keys are declared as optional.
    """

    # Position target; its error against the joint qpos is scaled by kp.
    position: NotRequired[float]
    # Velocity target; its error against the joint qvel is scaled by kd.
    velocity: NotRequired[float]
    # Feed-forward torque added to the PD terms.
    torque: NotRequired[float]
 45
 46
 47def get_integrator(integrator: str) -> mujoco.mjtIntegrator:
 48    match integrator.lower():
 49        case "euler":
 50            return mujoco.mjtIntegrator.mjINT_EULER
 51        case "implicit":
 52            return mujoco.mjtIntegrator.mjINT_IMPLICIT
 53        case "implicitfast":
 54            return mujoco.mjtIntegrator.mjINT_IMPLICITFAST
 55        case "rk4":
 56            return mujoco.mjtIntegrator.mjINT_RK4
 57        case _:
 58            raise ValueError(f"Invalid integrator: {integrator}")
 59
 60
 61class MujocoSimulator:
 62    def __init__(
 63        self,
 64        model_path: str | Path,
 65        model_metadata: RobotURDFMetadataOutput,
 66        dt: float = 0.001,
 67        gravity: bool = True,
 68        render_mode: Literal["window", "offscreen"] = "window",
 69        suspended: bool = False,
 70        start_height: float = 1.5,
 71        command_delay_min: float = 0.0,
 72        command_delay_max: float = 0.0,
 73        joint_pos_delta_noise: float = 0.0,
 74        joint_pos_noise: float = 0.0,
 75        joint_vel_noise: float = 0.0,
 76        pd_update_frequency: float = 100.0,
 77        mujoco_scene: str = "smooth",
 78        integrator: str = "implicitfast",
 79        camera: str | None = None,
 80        frame_width: int = 640,
 81        frame_height: int = 480,
 82    ) -> None:
 83        # Stores parameters.
 84        self._model_path = model_path
 85        self._metadata = model_metadata
 86        self._dt = dt
 87        self._gravity = gravity
 88        self._render_mode = render_mode
 89        self._suspended = suspended
 90        self._start_height = start_height
 91        self._command_delay_min = command_delay_min
 92        self._command_delay_max = command_delay_max
 93        self._joint_pos_delta_noise = math.radians(joint_pos_delta_noise)
 94        self._joint_pos_noise = math.radians(joint_pos_noise)
 95        self._joint_vel_noise = math.radians(joint_vel_noise)
 96        self._update_pd_delta = 1.0 / pd_update_frequency
 97        self._camera = camera
 98
 99        # Gets the sim decimation.
100        if (control_frequency := self._metadata.control_frequency) is None:
101            raise ValueError("Control frequency is not set")
102        self._control_frequency = float(control_frequency)
103        self._control_dt = 1.0 / self._control_frequency
104        self._sim_decimation = int(self._control_dt / self._dt)
105
106        # Gets the joint name mapping.
107        if self._metadata.joint_name_to_metadata is None:
108            raise ValueError("Joint name to metadata is not set")
109
110        # Gets the IDs, KPs, and KDs for each joint.
111        self._joint_name_to_id = {name: _nn(joint.id) for name, joint in self._metadata.joint_name_to_metadata.items()}
112        self._joint_name_to_kp: dict[str, float] = {
113            name: float(_nn(joint.kp)) for name, joint in self._metadata.joint_name_to_metadata.items()
114        }
115        self._joint_name_to_kd: dict[str, float] = {
116            name: float(_nn(joint.kd)) for name, joint in self._metadata.joint_name_to_metadata.items()
117        }
118        self._joint_name_to_max_torque: dict[str, float] = {}
119
120        # Gets the inverse mapping.
121        self._joint_id_to_name = {v: k for k, v in self._joint_name_to_id.items()}
122        if len(self._joint_name_to_id) != len(self._joint_id_to_name):
123            raise ValueError("Joint IDs are not unique!")
124
125        # Chooses some random deltas for the joint positions.
126        self._joint_name_to_pos_delta = {
127            name: random.uniform(-self._joint_pos_delta_noise, self._joint_pos_delta_noise)
128            for name in self._joint_name_to_id
129        }
130
131        # Load MuJoCo model and initialize data
132        logger.info("Loading model from %s", model_path)
133        self._model = load_mjmodel(model_path, mujoco_scene)
134        self._model.opt.timestep = self._dt
135        self._model.opt.integrator = get_integrator(integrator)
136        self._model.opt.solver = mujoco.mjtSolver.mjSOL_CG
137
138        self._data = mujoco.MjData(self._model)
139
140        logger.info("Joint ID to name: %s", self._joint_id_to_name)
141
142        if not self._gravity:
143            self._model.opt.gravity[2] = 0.0
144
145        # Initialize velocities and accelerations to zero
146        self._data.qpos[:3] = np.array([0.0, 0.0, self._start_height])
147        self._data.qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0])
148        self._data.qpos[7:] = np.zeros_like(self._data.qpos[7:])
149        self._data.qvel = np.zeros_like(self._data.qvel)
150        self._data.qacc = np.zeros_like(self._data.qacc)
151
152        # Important: Step simulation once to initialize internal structures
153        mujoco.mj_forward(self._model, self._data)
154        mujoco.mj_step(self._model, self._data)
155
156        # Setup viewer after initial step
157        self._render_enabled = self._render_mode == "window"
158        self._viewer = mujoco_viewer.MujocoViewer(
159            self._model,
160            self._data,
161            mode=self._render_mode,
162            width=frame_width,
163            height=frame_height,
164        )
165
166        if self._camera is not None:
167            camera_obj = self._model.camera(self._camera)
168            self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
169            self._viewer.cam.trackbodyid = camera_obj.id
170
171        # Cache lookups after initialization
172        self._sensor_name_to_id = {self._model.sensor(i).name: i for i in range(self._model.nsensor)}
173        logger.debug("Sensor IDs: %s", self._sensor_name_to_id)
174
175        self._actuator_name_to_id = {self._model.actuator(i).name: i for i in range(self._model.nu)}
176        logger.debug("Actuator IDs: %s", self._actuator_name_to_id)
177
178        # There is an important distinction between actuator IDs and joint IDs.
179        # joint IDs should be at the kos layer, where the canonical IDs are assigned (see docs.kscale.dev)
180        # but actuator IDs are at the mujoco layer, where the actuators actually get mapped.
181        logger.debug("Joint ID to name: %s", self._joint_id_to_name)
182        self._joint_id_to_actuator_id = {
183            joint_id: self._actuator_name_to_id[f"{name}_ctrl"] for joint_id, name in self._joint_id_to_name.items()
184        }
185        self._actuator_id_to_joint_id = {
186            actuator_id: joint_id for joint_id, actuator_id in self._joint_id_to_actuator_id.items()
187        }
188
189        # Add control parameters
190        self._sim_time = time.time()
191        self._current_commands: dict[str, ActuatorCommand] = {
192            name: {"position": 0.0, "velocity": 0.0, "torque": 0.0} for name in self._joint_name_to_id
193        }
194        self._next_commands: dict[str, tuple[ActuatorCommand, float]] = {}
195
196    async def step(self) -> None:
197        """Execute one step of the simulation."""
198        self._sim_time += self._dt
199
200        # Process commands that are ready to be applied
201        commands_to_remove = []
202        for name, (target_command, application_time) in self._next_commands.items():
203            if self._sim_time >= application_time:
204                self._current_commands[name] = target_command
205                commands_to_remove.append(name)
206
207        # Remove processed commands
208        if commands_to_remove:
209            for name in commands_to_remove:
210                self._next_commands.pop(name)
211
212        mujoco.mj_forward(self._model, self._data)
213
214        # Sets the ctrl values from the current commands.
215        for name, target_command in self._current_commands.items():
216            joint_id = self._joint_name_to_id[name]
217            actuator_id = self._joint_id_to_actuator_id[joint_id]
218            kp = self._joint_name_to_kp[name]
219            kd = self._joint_name_to_kd[name]
220            current_position = self._data.joint(name).qpos
221            current_velocity = self._data.joint(name).qvel
222            target_torque = (
223                kp * (target_command["position"] - current_position)
224                + kd * (target_command["velocity"] - current_velocity)
225                + target_command["torque"]
226            )
227            if (max_torque := self._joint_name_to_max_torque.get(name)) is not None:
228                target_torque = np.clip(target_torque, -max_torque, max_torque)
229            logger.debug("Setting ctrl for actuator %s to %f", actuator_id, target_torque)
230            self._data.ctrl[actuator_id] = target_torque
231
232        # Step physics - allow other coroutines to run during computation
233
234        # for some reason running forward before step makes it more stable.
235        # It possibly computes some values that are needed for the step.
236        mujoco.mj_forward(self._model, self._data)
237        mujoco.mj_step(self._model, self._data)
238        if self._suspended:
239            # Find the root joint (floating_base)
240            for i in range(self._model.njnt):
241                if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE:
242                    self._data.qpos[i : i + 7] = [0.0, 0.0, self._start_height, 0.0, 0.0, 0.0, 1.0]
243                    self._data.qvel[i : i + 6] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
244                    break
245
246        return self._data
247
248    async def render(self) -> None:
249        """Render the simulation asynchronously."""
250        if self._render_enabled:
251            self._viewer.render()
252
253    async def capture_frame(self, camid: int = -1, depth: bool = False) -> tuple[np.ndarray, np.ndarray | None]:
254        """Capture a frame from the simulation using read_pixels.
255
256        Args:
257            camid: Camera ID to use (-1 for free camera)
258            depth: Whether to return depth information
259
260        Returns:
261            RGB image array (and optionally depth array) if depth=True
262        """
263        if self._render_mode != "offscreen" and self._render_enabled:
264            logger.warning("Capturing frames is more efficient in offscreen mode")
265
266        if camid is not None:
267            if camid == -1:
268                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
269            else:
270                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
271                self._viewer.cam.fixedcamid = camid
272
273        if depth:
274            rgb, depth_img = self._viewer.read_pixels(depth=True)
275            return rgb, depth_img
276        else:
277            rgb = self._viewer.read_pixels()
278            return rgb, None
279
280    async def get_sensor_data(self, name: str) -> np.ndarray:
281        """Get data from a named sensor."""
282        if name not in self._sensor_name_to_id:
283            raise KeyError(f"Sensor '{name}' not found")
284        sensor_id = self._sensor_name_to_id[name]
285        return self._data.sensor(sensor_id).data.copy()
286
287    async def get_actuator_state(self, joint_id: int) -> ActuatorState:
288        """Get current state of an actuator using real joint ID."""
289        if joint_id not in self._joint_id_to_name:
290            raise KeyError(f"Joint ID {joint_id} not found in config mappings")
291
292        joint_name = self._joint_id_to_name[joint_id]
293        joint_data = self._data.joint(joint_name)
294
295        return ActuatorState(
296            position=float(joint_data.qpos)
297            + self._joint_name_to_pos_delta[joint_name]
298            + random.uniform(-self._joint_pos_noise, self._joint_pos_noise),
299            velocity=float(joint_data.qvel) + random.uniform(-self._joint_vel_noise, self._joint_vel_noise),
300        )
301
302    async def command_actuators(self, commands: dict[int, ActuatorCommand]) -> None:
303        """Command multiple actuators at once using real joint IDs."""
304        for joint_id, command in commands.items():
305            # Translate real joint ID to MuJoCo joint name
306            if joint_id not in self._joint_id_to_name:
307                logger.warning("Joint ID %d not found in config mappings", joint_id)
308                continue
309
310            joint_name = self._joint_id_to_name[joint_id]
311            actuator_name = f"{joint_name}_ctrl"
312            if actuator_name not in self._actuator_name_to_id:
313                logger.warning("Joint %s not found in MuJoCo model", actuator_name)
314                continue
315
316            # Calculate random delay and application time
317            delay = np.random.uniform(self._command_delay_min, self._command_delay_max)
318            application_time = self._sim_time + delay
319
320            self._next_commands[joint_name] = (command, application_time)
321
322    async def configure_actuator(self, joint_id: int, configuration: ConfigureActuatorRequest) -> None:
323        """Configure an actuator using real joint ID."""
324        if joint_id not in self._joint_id_to_actuator_id:
325            raise KeyError(
326                f"Joint ID {joint_id} not found in config mappings. "
327                f"The available joint IDs are {self._joint_id_to_actuator_id.keys()}"
328            )
329
330        joint_name = self._joint_id_to_name[joint_id]
331        if "kp" in configuration:
332            self._joint_name_to_kp[joint_name] = configuration["kp"]
333        if "kd" in configuration:
334            self._joint_name_to_kd[joint_name] = configuration["kd"]
335        if "max_torque" in configuration:
336            self._joint_name_to_max_torque[joint_name] = configuration["max_torque"]
337
338    @property
339    def sim_time(self) -> float:
340        return self._sim_time
341
342    async def reset(
343        self,
344        xyz: tuple[float, float, float] | None = None,
345        quat: tuple[float, float, float, float] | None = None,
346        joint_pos: dict[str, float] | None = None,
347        joint_vel: dict[str, float] | None = None,
348    ) -> None:
349        """Reset simulation to specified or default state."""
350        self._next_commands.clear()
351
352        mujoco.mj_resetData(self._model, self._data)
353
354        # Resets qpos.
355        qpos = np.zeros_like(self._data.qpos)
356        qpos[:3] = np.array([0.0, 0.0, self._start_height] if xyz is None else xyz)
357        qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0] if quat is None else quat)
358        qpos[7:] = np.zeros_like(self._data.qpos[7:])
359        if joint_pos is not None:
360            for joint_name, position in joint_pos.items():
361                self._data.joint(joint_name).qpos = position
362
363        # Resets qvel.
364        qvel = np.zeros_like(self._data.qvel)
365        if joint_vel is not None:
366            for joint_name, velocity in joint_vel.items():
367                self._data.joint(joint_name).qvel = velocity
368
369        # Resets qacc.
370        qacc = np.zeros_like(self._data.qacc)
371
372        # Runs one step.
373        self._data.qpos[:] = qpos
374        self._data.qvel[:] = qvel
375        self._data.qacc[:] = qacc
376        mujoco.mj_forward(self._model, self._data)
377
378    async def close(self) -> None:
379        """Clean up simulation resources."""
380        if self._viewer is not None:
381            try:
382                self._viewer.close()
383            except Exception as e:
384                logger.error("Error closing viewer: %s", e)
385            self._viewer = None
386
387    @property
388    def timestep(self) -> float:
389        return self._model.opt.timestep
class ConfigureActuatorRequest(typing.TypedDict):
class ConfigureActuatorRequest(TypedDict):
    """Optional per-actuator settings accepted by ``MujocoSimulator.configure_actuator``.

    Every key may be omitted; ``configure_actuator`` only applies the keys that
    are present in the request.
    """

    # Declared here, but not read by ``configure_actuator`` as written.
    torque_enabled: NotRequired[bool]
    # Declared here, but not read by ``configure_actuator`` as written.
    zero_position: NotRequired[float]
    # Proportional gain; replaces the joint's stored kp used by the PD law in ``step``.
    kp: NotRequired[float]
    # Derivative gain; replaces the joint's stored kd used by the PD law in ``step``.
    kd: NotRequired[float]
    # Symmetric bound; ``step`` clips the commanded torque to [-max_torque, max_torque].
    max_torque: NotRequired[float]
torque_enabled: NotRequired[bool]
zero_position: NotRequired[float]
kp: NotRequired[float]
kd: NotRequired[float]
max_torque: NotRequired[float]
@dataclass
class ActuatorState:
@dataclass
class ActuatorState:
    """Joint reading returned by ``MujocoSimulator.get_actuator_state``."""

    # Joint qpos plus the fixed per-joint offset and fresh uniform noise.
    position: float
    # Joint qvel plus fresh uniform noise.
    velocity: float
ActuatorState(position: float, velocity: float)
position: float
velocity: float
class ActuatorCommand(typing.TypedDict):
class ActuatorCommand(TypedDict):
    """Target for one actuator, consumed by the PD law in ``MujocoSimulator.step``.

    All keys are declared as optional.
    """

    # Position target; its error against the joint qpos is scaled by kp.
    position: NotRequired[float]
    # Velocity target; its error against the joint qvel is scaled by kd.
    velocity: NotRequired[float]
    # Feed-forward torque added to the PD terms.
    torque: NotRequired[float]
position: NotRequired[float]
velocity: NotRequired[float]
torque: NotRequired[float]
def get_integrator(integrator: str) -> mujoco._enums.mjtIntegrator:
def get_integrator(integrator: str) -> mujoco.mjtIntegrator:
    """Translate an integrator name (case-insensitive) into the MuJoCo enum value.

    Args:
        integrator: One of "euler", "implicit", "implicitfast" or "rk4".

    Returns:
        The matching ``mujoco.mjtIntegrator`` member.

    Raises:
        ValueError: If the name is not one of the supported integrators.
    """
    by_name = {
        "euler": mujoco.mjtIntegrator.mjINT_EULER,
        "implicit": mujoco.mjtIntegrator.mjINT_IMPLICIT,
        "implicitfast": mujoco.mjtIntegrator.mjINT_IMPLICITFAST,
        "rk4": mujoco.mjtIntegrator.mjINT_RK4,
    }
    key = integrator.lower()
    if key not in by_name:
        raise ValueError(f"Invalid integrator: {integrator}")
    return by_name[key]
class MujocoSimulator:
 62class MujocoSimulator:
 63    def __init__(
 64        self,
 65        model_path: str | Path,
 66        model_metadata: RobotURDFMetadataOutput,
 67        dt: float = 0.001,
 68        gravity: bool = True,
 69        render_mode: Literal["window", "offscreen"] = "window",
 70        suspended: bool = False,
 71        start_height: float = 1.5,
 72        command_delay_min: float = 0.0,
 73        command_delay_max: float = 0.0,
 74        joint_pos_delta_noise: float = 0.0,
 75        joint_pos_noise: float = 0.0,
 76        joint_vel_noise: float = 0.0,
 77        pd_update_frequency: float = 100.0,
 78        mujoco_scene: str = "smooth",
 79        integrator: str = "implicitfast",
 80        camera: str | None = None,
 81        frame_width: int = 640,
 82        frame_height: int = 480,
 83    ) -> None:
 84        # Stores parameters.
 85        self._model_path = model_path
 86        self._metadata = model_metadata
 87        self._dt = dt
 88        self._gravity = gravity
 89        self._render_mode = render_mode
 90        self._suspended = suspended
 91        self._start_height = start_height
 92        self._command_delay_min = command_delay_min
 93        self._command_delay_max = command_delay_max
 94        self._joint_pos_delta_noise = math.radians(joint_pos_delta_noise)
 95        self._joint_pos_noise = math.radians(joint_pos_noise)
 96        self._joint_vel_noise = math.radians(joint_vel_noise)
 97        self._update_pd_delta = 1.0 / pd_update_frequency
 98        self._camera = camera
 99
100        # Gets the sim decimation.
101        if (control_frequency := self._metadata.control_frequency) is None:
102            raise ValueError("Control frequency is not set")
103        self._control_frequency = float(control_frequency)
104        self._control_dt = 1.0 / self._control_frequency
105        self._sim_decimation = int(self._control_dt / self._dt)
106
107        # Gets the joint name mapping.
108        if self._metadata.joint_name_to_metadata is None:
109            raise ValueError("Joint name to metadata is not set")
110
111        # Gets the IDs, KPs, and KDs for each joint.
112        self._joint_name_to_id = {name: _nn(joint.id) for name, joint in self._metadata.joint_name_to_metadata.items()}
113        self._joint_name_to_kp: dict[str, float] = {
114            name: float(_nn(joint.kp)) for name, joint in self._metadata.joint_name_to_metadata.items()
115        }
116        self._joint_name_to_kd: dict[str, float] = {
117            name: float(_nn(joint.kd)) for name, joint in self._metadata.joint_name_to_metadata.items()
118        }
119        self._joint_name_to_max_torque: dict[str, float] = {}
120
121        # Gets the inverse mapping.
122        self._joint_id_to_name = {v: k for k, v in self._joint_name_to_id.items()}
123        if len(self._joint_name_to_id) != len(self._joint_id_to_name):
124            raise ValueError("Joint IDs are not unique!")
125
126        # Chooses some random deltas for the joint positions.
127        self._joint_name_to_pos_delta = {
128            name: random.uniform(-self._joint_pos_delta_noise, self._joint_pos_delta_noise)
129            for name in self._joint_name_to_id
130        }
131
132        # Load MuJoCo model and initialize data
133        logger.info("Loading model from %s", model_path)
134        self._model = load_mjmodel(model_path, mujoco_scene)
135        self._model.opt.timestep = self._dt
136        self._model.opt.integrator = get_integrator(integrator)
137        self._model.opt.solver = mujoco.mjtSolver.mjSOL_CG
138
139        self._data = mujoco.MjData(self._model)
140
141        logger.info("Joint ID to name: %s", self._joint_id_to_name)
142
143        if not self._gravity:
144            self._model.opt.gravity[2] = 0.0
145
146        # Initialize velocities and accelerations to zero
147        self._data.qpos[:3] = np.array([0.0, 0.0, self._start_height])
148        self._data.qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0])
149        self._data.qpos[7:] = np.zeros_like(self._data.qpos[7:])
150        self._data.qvel = np.zeros_like(self._data.qvel)
151        self._data.qacc = np.zeros_like(self._data.qacc)
152
153        # Important: Step simulation once to initialize internal structures
154        mujoco.mj_forward(self._model, self._data)
155        mujoco.mj_step(self._model, self._data)
156
157        # Setup viewer after initial step
158        self._render_enabled = self._render_mode == "window"
159        self._viewer = mujoco_viewer.MujocoViewer(
160            self._model,
161            self._data,
162            mode=self._render_mode,
163            width=frame_width,
164            height=frame_height,
165        )
166
167        if self._camera is not None:
168            camera_obj = self._model.camera(self._camera)
169            self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
170            self._viewer.cam.trackbodyid = camera_obj.id
171
172        # Cache lookups after initialization
173        self._sensor_name_to_id = {self._model.sensor(i).name: i for i in range(self._model.nsensor)}
174        logger.debug("Sensor IDs: %s", self._sensor_name_to_id)
175
176        self._actuator_name_to_id = {self._model.actuator(i).name: i for i in range(self._model.nu)}
177        logger.debug("Actuator IDs: %s", self._actuator_name_to_id)
178
179        # There is an important distinction between actuator IDs and joint IDs.
180        # joint IDs should be at the kos layer, where the canonical IDs are assigned (see docs.kscale.dev)
181        # but actuator IDs are at the mujoco layer, where the actuators actually get mapped.
182        logger.debug("Joint ID to name: %s", self._joint_id_to_name)
183        self._joint_id_to_actuator_id = {
184            joint_id: self._actuator_name_to_id[f"{name}_ctrl"] for joint_id, name in self._joint_id_to_name.items()
185        }
186        self._actuator_id_to_joint_id = {
187            actuator_id: joint_id for joint_id, actuator_id in self._joint_id_to_actuator_id.items()
188        }
189
190        # Add control parameters
191        self._sim_time = time.time()
192        self._current_commands: dict[str, ActuatorCommand] = {
193            name: {"position": 0.0, "velocity": 0.0, "torque": 0.0} for name in self._joint_name_to_id
194        }
195        self._next_commands: dict[str, tuple[ActuatorCommand, float]] = {}
196
197    async def step(self) -> None:
198        """Execute one step of the simulation."""
199        self._sim_time += self._dt
200
201        # Process commands that are ready to be applied
202        commands_to_remove = []
203        for name, (target_command, application_time) in self._next_commands.items():
204            if self._sim_time >= application_time:
205                self._current_commands[name] = target_command
206                commands_to_remove.append(name)
207
208        # Remove processed commands
209        if commands_to_remove:
210            for name in commands_to_remove:
211                self._next_commands.pop(name)
212
213        mujoco.mj_forward(self._model, self._data)
214
215        # Sets the ctrl values from the current commands.
216        for name, target_command in self._current_commands.items():
217            joint_id = self._joint_name_to_id[name]
218            actuator_id = self._joint_id_to_actuator_id[joint_id]
219            kp = self._joint_name_to_kp[name]
220            kd = self._joint_name_to_kd[name]
221            current_position = self._data.joint(name).qpos
222            current_velocity = self._data.joint(name).qvel
223            target_torque = (
224                kp * (target_command["position"] - current_position)
225                + kd * (target_command["velocity"] - current_velocity)
226                + target_command["torque"]
227            )
228            if (max_torque := self._joint_name_to_max_torque.get(name)) is not None:
229                target_torque = np.clip(target_torque, -max_torque, max_torque)
230            logger.debug("Setting ctrl for actuator %s to %f", actuator_id, target_torque)
231            self._data.ctrl[actuator_id] = target_torque
232
233        # Step physics - allow other coroutines to run during computation
234
235        # for some reason running forward before step makes it more stable.
236        # It possibly computes some values that are needed for the step.
237        mujoco.mj_forward(self._model, self._data)
238        mujoco.mj_step(self._model, self._data)
239        if self._suspended:
240            # Find the root joint (floating_base)
241            for i in range(self._model.njnt):
242                if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE:
243                    self._data.qpos[i : i + 7] = [0.0, 0.0, self._start_height, 0.0, 0.0, 0.0, 1.0]
244                    self._data.qvel[i : i + 6] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
245                    break
246
247        return self._data
248
249    async def render(self) -> None:
250        """Render the simulation asynchronously."""
251        if self._render_enabled:
252            self._viewer.render()
253
254    async def capture_frame(self, camid: int = -1, depth: bool = False) -> tuple[np.ndarray, np.ndarray | None]:
255        """Capture a frame from the simulation using read_pixels.
256
257        Args:
258            camid: Camera ID to use (-1 for free camera)
259            depth: Whether to return depth information
260
261        Returns:
262            RGB image array (and optionally depth array) if depth=True
263        """
264        if self._render_mode != "offscreen" and self._render_enabled:
265            logger.warning("Capturing frames is more efficient in offscreen mode")
266
267        if camid is not None:
268            if camid == -1:
269                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
270            else:
271                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
272                self._viewer.cam.fixedcamid = camid
273
274        if depth:
275            rgb, depth_img = self._viewer.read_pixels(depth=True)
276            return rgb, depth_img
277        else:
278            rgb = self._viewer.read_pixels()
279            return rgb, None
280
281    async def get_sensor_data(self, name: str) -> np.ndarray:
282        """Get data from a named sensor."""
283        if name not in self._sensor_name_to_id:
284            raise KeyError(f"Sensor '{name}' not found")
285        sensor_id = self._sensor_name_to_id[name]
286        return self._data.sensor(sensor_id).data.copy()
287
288    async def get_actuator_state(self, joint_id: int) -> ActuatorState:
289        """Get current state of an actuator using real joint ID."""
290        if joint_id not in self._joint_id_to_name:
291            raise KeyError(f"Joint ID {joint_id} not found in config mappings")
292
293        joint_name = self._joint_id_to_name[joint_id]
294        joint_data = self._data.joint(joint_name)
295
296        return ActuatorState(
297            position=float(joint_data.qpos)
298            + self._joint_name_to_pos_delta[joint_name]
299            + random.uniform(-self._joint_pos_noise, self._joint_pos_noise),
300            velocity=float(joint_data.qvel) + random.uniform(-self._joint_vel_noise, self._joint_vel_noise),
301        )
302
303    async def command_actuators(self, commands: dict[int, ActuatorCommand]) -> None:
304        """Command multiple actuators at once using real joint IDs."""
305        for joint_id, command in commands.items():
306            # Translate real joint ID to MuJoCo joint name
307            if joint_id not in self._joint_id_to_name:
308                logger.warning("Joint ID %d not found in config mappings", joint_id)
309                continue
310
311            joint_name = self._joint_id_to_name[joint_id]
312            actuator_name = f"{joint_name}_ctrl"
313            if actuator_name not in self._actuator_name_to_id:
314                logger.warning("Joint %s not found in MuJoCo model", actuator_name)
315                continue
316
317            # Calculate random delay and application time
318            delay = np.random.uniform(self._command_delay_min, self._command_delay_max)
319            application_time = self._sim_time + delay
320
321            self._next_commands[joint_name] = (command, application_time)
322
323    async def configure_actuator(self, joint_id: int, configuration: ConfigureActuatorRequest) -> None:
324        """Configure an actuator using real joint ID."""
325        if joint_id not in self._joint_id_to_actuator_id:
326            raise KeyError(
327                f"Joint ID {joint_id} not found in config mappings. "
328                f"The available joint IDs are {self._joint_id_to_actuator_id.keys()}"
329            )
330
331        joint_name = self._joint_id_to_name[joint_id]
332        if "kp" in configuration:
333            self._joint_name_to_kp[joint_name] = configuration["kp"]
334        if "kd" in configuration:
335            self._joint_name_to_kd[joint_name] = configuration["kd"]
336        if "max_torque" in configuration:
337            self._joint_name_to_max_torque[joint_name] = configuration["max_torque"]
338
339    @property
340    def sim_time(self) -> float:
341        return self._sim_time
342
343    async def reset(
344        self,
345        xyz: tuple[float, float, float] | None = None,
346        quat: tuple[float, float, float, float] | None = None,
347        joint_pos: dict[str, float] | None = None,
348        joint_vel: dict[str, float] | None = None,
349    ) -> None:
350        """Reset simulation to specified or default state."""
351        self._next_commands.clear()
352
353        mujoco.mj_resetData(self._model, self._data)
354
355        # Resets qpos.
356        qpos = np.zeros_like(self._data.qpos)
357        qpos[:3] = np.array([0.0, 0.0, self._start_height] if xyz is None else xyz)
358        qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0] if quat is None else quat)
359        qpos[7:] = np.zeros_like(self._data.qpos[7:])
360        if joint_pos is not None:
361            for joint_name, position in joint_pos.items():
362                self._data.joint(joint_name).qpos = position
363
364        # Resets qvel.
365        qvel = np.zeros_like(self._data.qvel)
366        if joint_vel is not None:
367            for joint_name, velocity in joint_vel.items():
368                self._data.joint(joint_name).qvel = velocity
369
370        # Resets qacc.
371        qacc = np.zeros_like(self._data.qacc)
372
373        # Runs one step.
374        self._data.qpos[:] = qpos
375        self._data.qvel[:] = qvel
376        self._data.qacc[:] = qacc
377        mujoco.mj_forward(self._model, self._data)
378
379    async def close(self) -> None:
380        """Clean up simulation resources."""
381        if self._viewer is not None:
382            try:
383                self._viewer.close()
384            except Exception as e:
385                logger.error("Error closing viewer: %s", e)
386            self._viewer = None
387
388    @property
389    def timestep(self) -> float:
390        return self._model.opt.timestep
MujocoSimulator( model_path: str | pathlib.Path, model_metadata: kscale.web.gen.api.RobotURDFMetadataOutput, dt: float = 0.001, gravity: bool = True, render_mode: Literal['window', 'offscreen'] = 'window', suspended: bool = False, start_height: float = 1.5, command_delay_min: float = 0.0, command_delay_max: float = 0.0, joint_pos_delta_noise: float = 0.0, joint_pos_noise: float = 0.0, joint_vel_noise: float = 0.0, pd_update_frequency: float = 100.0, mujoco_scene: str = 'smooth', integrator: str = 'implicitfast', camera: str | None = None, frame_width: int = 640, frame_height: int = 480)
 63    def __init__(
 64        self,
 65        model_path: str | Path,
 66        model_metadata: RobotURDFMetadataOutput,
 67        dt: float = 0.001,
 68        gravity: bool = True,
 69        render_mode: Literal["window", "offscreen"] = "window",
 70        suspended: bool = False,
 71        start_height: float = 1.5,
 72        command_delay_min: float = 0.0,
 73        command_delay_max: float = 0.0,
 74        joint_pos_delta_noise: float = 0.0,
 75        joint_pos_noise: float = 0.0,
 76        joint_vel_noise: float = 0.0,
 77        pd_update_frequency: float = 100.0,
 78        mujoco_scene: str = "smooth",
 79        integrator: str = "implicitfast",
 80        camera: str | None = None,
 81        frame_width: int = 640,
 82        frame_height: int = 480,
 83    ) -> None:
 84        # Stores parameters.
 85        self._model_path = model_path
 86        self._metadata = model_metadata
 87        self._dt = dt
 88        self._gravity = gravity
 89        self._render_mode = render_mode
 90        self._suspended = suspended
 91        self._start_height = start_height
 92        self._command_delay_min = command_delay_min
 93        self._command_delay_max = command_delay_max
 94        self._joint_pos_delta_noise = math.radians(joint_pos_delta_noise)
 95        self._joint_pos_noise = math.radians(joint_pos_noise)
 96        self._joint_vel_noise = math.radians(joint_vel_noise)
 97        self._update_pd_delta = 1.0 / pd_update_frequency
 98        self._camera = camera
 99
100        # Gets the sim decimation.
101        if (control_frequency := self._metadata.control_frequency) is None:
102            raise ValueError("Control frequency is not set")
103        self._control_frequency = float(control_frequency)
104        self._control_dt = 1.0 / self._control_frequency
105        self._sim_decimation = int(self._control_dt / self._dt)
106
107        # Gets the joint name mapping.
108        if self._metadata.joint_name_to_metadata is None:
109            raise ValueError("Joint name to metadata is not set")
110
111        # Gets the IDs, KPs, and KDs for each joint.
112        self._joint_name_to_id = {name: _nn(joint.id) for name, joint in self._metadata.joint_name_to_metadata.items()}
113        self._joint_name_to_kp: dict[str, float] = {
114            name: float(_nn(joint.kp)) for name, joint in self._metadata.joint_name_to_metadata.items()
115        }
116        self._joint_name_to_kd: dict[str, float] = {
117            name: float(_nn(joint.kd)) for name, joint in self._metadata.joint_name_to_metadata.items()
118        }
119        self._joint_name_to_max_torque: dict[str, float] = {}
120
121        # Gets the inverse mapping.
122        self._joint_id_to_name = {v: k for k, v in self._joint_name_to_id.items()}
123        if len(self._joint_name_to_id) != len(self._joint_id_to_name):
124            raise ValueError("Joint IDs are not unique!")
125
126        # Chooses some random deltas for the joint positions.
127        self._joint_name_to_pos_delta = {
128            name: random.uniform(-self._joint_pos_delta_noise, self._joint_pos_delta_noise)
129            for name in self._joint_name_to_id
130        }
131
132        # Load MuJoCo model and initialize data
133        logger.info("Loading model from %s", model_path)
134        self._model = load_mjmodel(model_path, mujoco_scene)
135        self._model.opt.timestep = self._dt
136        self._model.opt.integrator = get_integrator(integrator)
137        self._model.opt.solver = mujoco.mjtSolver.mjSOL_CG
138
139        self._data = mujoco.MjData(self._model)
140
141        logger.info("Joint ID to name: %s", self._joint_id_to_name)
142
143        if not self._gravity:
144            self._model.opt.gravity[2] = 0.0
145
146        # Initialize velocities and accelerations to zero
147        self._data.qpos[:3] = np.array([0.0, 0.0, self._start_height])
148        self._data.qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0])
149        self._data.qpos[7:] = np.zeros_like(self._data.qpos[7:])
150        self._data.qvel = np.zeros_like(self._data.qvel)
151        self._data.qacc = np.zeros_like(self._data.qacc)
152
153        # Important: Step simulation once to initialize internal structures
154        mujoco.mj_forward(self._model, self._data)
155        mujoco.mj_step(self._model, self._data)
156
157        # Setup viewer after initial step
158        self._render_enabled = self._render_mode == "window"
159        self._viewer = mujoco_viewer.MujocoViewer(
160            self._model,
161            self._data,
162            mode=self._render_mode,
163            width=frame_width,
164            height=frame_height,
165        )
166
167        if self._camera is not None:
168            camera_obj = self._model.camera(self._camera)
169            self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
170            self._viewer.cam.trackbodyid = camera_obj.id
171
172        # Cache lookups after initialization
173        self._sensor_name_to_id = {self._model.sensor(i).name: i for i in range(self._model.nsensor)}
174        logger.debug("Sensor IDs: %s", self._sensor_name_to_id)
175
176        self._actuator_name_to_id = {self._model.actuator(i).name: i for i in range(self._model.nu)}
177        logger.debug("Actuator IDs: %s", self._actuator_name_to_id)
178
179        # There is an important distinction between actuator IDs and joint IDs.
180        # joint IDs should be at the kos layer, where the canonical IDs are assigned (see docs.kscale.dev)
181        # but actuator IDs are at the mujoco layer, where the actuators actually get mapped.
182        logger.debug("Joint ID to name: %s", self._joint_id_to_name)
183        self._joint_id_to_actuator_id = {
184            joint_id: self._actuator_name_to_id[f"{name}_ctrl"] for joint_id, name in self._joint_id_to_name.items()
185        }
186        self._actuator_id_to_joint_id = {
187            actuator_id: joint_id for joint_id, actuator_id in self._joint_id_to_actuator_id.items()
188        }
189
190        # Add control parameters
191        self._sim_time = time.time()
192        self._current_commands: dict[str, ActuatorCommand] = {
193            name: {"position": 0.0, "velocity": 0.0, "torque": 0.0} for name in self._joint_name_to_id
194        }
195        self._next_commands: dict[str, tuple[ActuatorCommand, float]] = {}
async def step(self) -> None:
197    async def step(self) -> None:
198        """Execute one step of the simulation."""
199        self._sim_time += self._dt
200
201        # Process commands that are ready to be applied
202        commands_to_remove = []
203        for name, (target_command, application_time) in self._next_commands.items():
204            if self._sim_time >= application_time:
205                self._current_commands[name] = target_command
206                commands_to_remove.append(name)
207
208        # Remove processed commands
209        if commands_to_remove:
210            for name in commands_to_remove:
211                self._next_commands.pop(name)
212
213        mujoco.mj_forward(self._model, self._data)
214
215        # Sets the ctrl values from the current commands.
216        for name, target_command in self._current_commands.items():
217            joint_id = self._joint_name_to_id[name]
218            actuator_id = self._joint_id_to_actuator_id[joint_id]
219            kp = self._joint_name_to_kp[name]
220            kd = self._joint_name_to_kd[name]
221            current_position = self._data.joint(name).qpos
222            current_velocity = self._data.joint(name).qvel
223            target_torque = (
224                kp * (target_command["position"] - current_position)
225                + kd * (target_command["velocity"] - current_velocity)
226                + target_command["torque"]
227            )
228            if (max_torque := self._joint_name_to_max_torque.get(name)) is not None:
229                target_torque = np.clip(target_torque, -max_torque, max_torque)
230            logger.debug("Setting ctrl for actuator %s to %f", actuator_id, target_torque)
231            self._data.ctrl[actuator_id] = target_torque
232
233        # Step physics - allow other coroutines to run during computation
234
235        # for some reason running forward before step makes it more stable.
236        # It possibly computes some values that are needed for the step.
237        mujoco.mj_forward(self._model, self._data)
238        mujoco.mj_step(self._model, self._data)
239        if self._suspended:
240            # Find the root joint (floating_base)
241            for i in range(self._model.njnt):
242                if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE:
243                    self._data.qpos[i : i + 7] = [0.0, 0.0, self._start_height, 0.0, 0.0, 0.0, 1.0]
244                    self._data.qvel[i : i + 6] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
245                    break
246
247        return self._data

Execute one step of the simulation.

async def render(self) -> None:
249    async def render(self) -> None:
250        """Render the simulation asynchronously."""
251        if self._render_enabled:
252            self._viewer.render()

Render the simulation asynchronously.

async def capture_frame( self, camid: int = -1, depth: bool = False) -> tuple[numpy.ndarray, numpy.ndarray | None]:
254    async def capture_frame(self, camid: int = -1, depth: bool = False) -> tuple[np.ndarray, np.ndarray | None]:
255        """Capture a frame from the simulation using read_pixels.
256
257        Args:
258            camid: Camera ID to use (-1 for free camera)
259            depth: Whether to return depth information
260
261        Returns:
262            RGB image array (and optionally depth array) if depth=True
263        """
264        if self._render_mode != "offscreen" and self._render_enabled:
265            logger.warning("Capturing frames is more efficient in offscreen mode")
266
267        if camid is not None:
268            if camid == -1:
269                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
270            else:
271                self._viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
272                self._viewer.cam.fixedcamid = camid
273
274        if depth:
275            rgb, depth_img = self._viewer.read_pixels(depth=True)
276            return rgb, depth_img
277        else:
278            rgb = self._viewer.read_pixels()
279            return rgb, None

Capture a frame from the simulation using read_pixels.

Args: camid: Camera ID to use (-1 for free camera). depth: Whether to return depth information.

Returns: The RGB image array and, if depth=True, the depth array (otherwise None).

async def get_sensor_data(self, name: str) -> numpy.ndarray:
281    async def get_sensor_data(self, name: str) -> np.ndarray:
282        """Get data from a named sensor."""
283        if name not in self._sensor_name_to_id:
284            raise KeyError(f"Sensor '{name}' not found")
285        sensor_id = self._sensor_name_to_id[name]
286        return self._data.sensor(sensor_id).data.copy()

Get data from a named sensor.

async def get_actuator_state(self, joint_id: int) -> ActuatorState:
288    async def get_actuator_state(self, joint_id: int) -> ActuatorState:
289        """Get current state of an actuator using real joint ID."""
290        if joint_id not in self._joint_id_to_name:
291            raise KeyError(f"Joint ID {joint_id} not found in config mappings")
292
293        joint_name = self._joint_id_to_name[joint_id]
294        joint_data = self._data.joint(joint_name)
295
296        return ActuatorState(
297            position=float(joint_data.qpos)
298            + self._joint_name_to_pos_delta[joint_name]
299            + random.uniform(-self._joint_pos_noise, self._joint_pos_noise),
300            velocity=float(joint_data.qvel) + random.uniform(-self._joint_vel_noise, self._joint_vel_noise),
301        )

Get current state of an actuator using real joint ID.

async def command_actuators(self, commands: dict[int, ActuatorCommand]) -> None:
303    async def command_actuators(self, commands: dict[int, ActuatorCommand]) -> None:
304        """Command multiple actuators at once using real joint IDs."""
305        for joint_id, command in commands.items():
306            # Translate real joint ID to MuJoCo joint name
307            if joint_id not in self._joint_id_to_name:
308                logger.warning("Joint ID %d not found in config mappings", joint_id)
309                continue
310
311            joint_name = self._joint_id_to_name[joint_id]
312            actuator_name = f"{joint_name}_ctrl"
313            if actuator_name not in self._actuator_name_to_id:
314                logger.warning("Joint %s not found in MuJoCo model", actuator_name)
315                continue
316
317            # Calculate random delay and application time
318            delay = np.random.uniform(self._command_delay_min, self._command_delay_max)
319            application_time = self._sim_time + delay
320
321            self._next_commands[joint_name] = (command, application_time)

Command multiple actuators at once using real joint IDs.

async def configure_actuator( self, joint_id: int, configuration: ConfigureActuatorRequest) -> None:
323    async def configure_actuator(self, joint_id: int, configuration: ConfigureActuatorRequest) -> None:
324        """Configure an actuator using real joint ID."""
325        if joint_id not in self._joint_id_to_actuator_id:
326            raise KeyError(
327                f"Joint ID {joint_id} not found in config mappings. "
328                f"The available joint IDs are {self._joint_id_to_actuator_id.keys()}"
329            )
330
331        joint_name = self._joint_id_to_name[joint_id]
332        if "kp" in configuration:
333            self._joint_name_to_kp[joint_name] = configuration["kp"]
334        if "kd" in configuration:
335            self._joint_name_to_kd[joint_name] = configuration["kd"]
336        if "max_torque" in configuration:
337            self._joint_name_to_max_torque[joint_name] = configuration["max_torque"]

Configure an actuator using real joint ID.

sim_time: float
339    @property
340    def sim_time(self) -> float:
341        return self._sim_time
async def reset( self, xyz: tuple[float, float, float] | None = None, quat: tuple[float, float, float, float] | None = None, joint_pos: dict[str, float] | None = None, joint_vel: dict[str, float] | None = None) -> None:
343    async def reset(
344        self,
345        xyz: tuple[float, float, float] | None = None,
346        quat: tuple[float, float, float, float] | None = None,
347        joint_pos: dict[str, float] | None = None,
348        joint_vel: dict[str, float] | None = None,
349    ) -> None:
350        """Reset simulation to specified or default state."""
351        self._next_commands.clear()
352
353        mujoco.mj_resetData(self._model, self._data)
354
355        # Resets qpos.
356        qpos = np.zeros_like(self._data.qpos)
357        qpos[:3] = np.array([0.0, 0.0, self._start_height] if xyz is None else xyz)
358        qpos[3:7] = np.array([0.0, 0.0, 0.0, 1.0] if quat is None else quat)
359        qpos[7:] = np.zeros_like(self._data.qpos[7:])
360        if joint_pos is not None:
361            for joint_name, position in joint_pos.items():
362                self._data.joint(joint_name).qpos = position
363
364        # Resets qvel.
365        qvel = np.zeros_like(self._data.qvel)
366        if joint_vel is not None:
367            for joint_name, velocity in joint_vel.items():
368                self._data.joint(joint_name).qvel = velocity
369
370        # Resets qacc.
371        qacc = np.zeros_like(self._data.qacc)
372
373        # Runs one step.
374        self._data.qpos[:] = qpos
375        self._data.qvel[:] = qvel
376        self._data.qacc[:] = qacc
377        mujoco.mj_forward(self._model, self._data)

Reset simulation to specified or default state.

async def close(self) -> None:
379    async def close(self) -> None:
380        """Clean up simulation resources."""
381        if self._viewer is not None:
382            try:
383                self._viewer.close()
384            except Exception as e:
385                logger.error("Error closing viewer: %s", e)
386            self._viewer = None

Clean up simulation resources.

timestep: float
388    @property
389    def timestep(self) -> float:
390        return self._model.opt.timestep