Unverified commit b040b1c3 authored by: W Weiyuhong-1998, committed by: GitHub

feature(wyh): multi agent mujoco environment (#146)

* ma mujoco env and masac code

* env(wyh):ma mujoco agent id

* feature(wyh):maqac continuous

* fix(wyh):multi-mujoco add readme

* fix(wyh): td error

* fix(wyh): style

* fix(wyh):multi agent mujoco test
Parent 02bd3300
......@@ -208,6 +208,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/gym_hybrid_zh.html) |
| 22 | [GoBigger](https://github.com/opendilab/GoBigger) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen)![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](./dizoo/gobigger_overview.gif) | [opendilab link](https://github.com/opendilab/GoBigger-Challenge-2021/tree/main/di_baseline)<br>[env tutorial](https://gobigger.readthedocs.io/en/latest/index.html)<br>[环境指南](https://gobigger.readthedocs.io/zh_CN/latest/) |
| 23 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/gym_soccer_zh.html) |
| 24 |[multiagent_mujoco](https://github.com/schroederdewitt/multiagent_mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_mujoco/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/mujoco_zh.html) |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
......
......@@ -15,5 +15,5 @@ from .qtran import QTran
from .mavac import MAVAC
from .ngu import NGU
from .qac_dist import QACDIST
from .maqac import MAQAC
from .maqac import MAQAC, ContinuousMAQAC
from .model_based import EnsembleDynamicsModel
from typing import Union, Dict, Optional
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn
......@@ -94,6 +96,215 @@ class MAQAC(nn.Module):
Overview:
Use observation and action tensor to predict output.
Parameter updates with QAC's MLPs forward setup.
Arguments:
Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` is used depends on ``mode``.
Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys:
- ``obs``, ``action`` encoded tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Outputs of network forward.
Forward with ``'compute_actor'``, Necessary Keys (either):
- action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``.
- logit (:obj:`torch.Tensor`): Action's probabilities.
Forward with ``'compute_critic'``, Necessary Keys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Actor Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
Critic Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``global_obs_shape``
- logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
"""
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)
def compute_actor(self, inputs: torch.Tensor) -> Dict:
r"""
Overview:
Use encoded embedding tensor to predict output.
Execute parameter updates with ``'compute_actor'`` mode.
Arguments:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
``hidden_size = actor_head_hidden_size``
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Outputs of forward pass encoder and head.
ReturnsKeys (either):
- action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- logit (:obj:`list`): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size.
Examples:
>>> # Regression mode
>>> model = QAC(64, 64, 'regression')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 64])
>>> # Reparameterization Mode
>>> model = QAC(64, 64, 'reparameterization')
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs,'compute_actor')
>>> actor_outputs['logit'][0].shape # mu
>>> torch.Size([4, 64])
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
"""
action_mask = inputs['obs']['action_mask']
x = self.actor(inputs['obs']['agent_state'])
return {'logit': x['logit'], 'action_mask': action_mask}
def compute_critic(self, inputs: Dict) -> Dict:
r"""
Overview:
Execute parameter updates with ``'compute_critic'`` mode
Use encoded embedding tensor to predict output.
Arguments:
- ``obs``, ``action`` encoded tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Q-value output.
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
"""
if self.twin_critic:
x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
else:
x = self.critic(inputs['obs']['global_state'])['logit']
return {'q_value': x}
@MODEL_REGISTRY.register('maqac_continuous')
class ContinuousMAQAC(nn.Module):
r"""
Overview:
The Continuous MAQAC model.
Interfaces:
``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
"""
mode = ['compute_actor', 'compute_critic']
def __init__(
self,
agent_obs_shape: Union[int, SequenceType],
global_obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
actor_head_type: str,
twin_critic: bool = False,
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
) -> None:
r"""
Overview:
Initialize the ContinuousMAQAC model according to the given arguments.
Arguments:
- agent_obs_shape (:obj:`Union[int, SequenceType]`): Each agent's local observation space.
- global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ),
EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- actor_head_type (:obj:`str`): Whether to use ``regression`` or ``reparameterization`` for the actor head.
- twin_critic (:obj:`bool`): Whether to include a twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
The number of layers used in the actor network's head.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
- critic_head_layer_num (:obj:`int`):
The number of layers used in the critic network to compute the Q value.
- activation (:obj:`Optional[nn.Module]`):
The type of activation function to use in the ``MLP`` after ``layer_fn``;
if ``None``, it defaults to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`):
The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
"""
super(ContinuousMAQAC, self).__init__()
obs_shape: int = squeeze(agent_obs_shape)
global_obs_shape: int = squeeze(global_obs_shape)
action_shape = squeeze(action_shape)
self.action_shape = action_shape
self.actor_head_type = actor_head_type
assert self.actor_head_type in ['regression', 'reparameterization']
if self.actor_head_type == 'regression': # DDPG, TD3
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
RegressionHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
final_tanh=True,
activation=activation,
norm_type=norm_type
)
)
else: # SAC
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type='conditioned',
activation=activation,
norm_type=norm_type
)
)
self.twin_critic = twin_critic
critic_input_size = global_obs_shape + action_shape
if self.twin_critic:
self.critic = nn.ModuleList()
for _ in range(2):
self.critic.append(
nn.Sequential(
nn.Linear(critic_input_size, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size,
1,
critic_head_layer_num,
final_tanh=False,
activation=activation,
norm_type=norm_type
)
)
)
else:
self.critic = nn.Sequential(
nn.Linear(critic_input_size, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size,
1,
critic_head_layer_num,
final_tanh=False,
activation=activation,
norm_type=norm_type
)
)
def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
r"""
Overview:
Use observation and action tensor to predict output.
Parameter updates with QAC's MLPs forward setup.
Arguments:
Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
......@@ -167,11 +378,16 @@ class MAQAC(nn.Module):
- action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
- logit + action_args
Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- logit (:obj:`list`): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
- logit (:obj:`Union[list, torch.Tensor]`):
- case1(continuous space, list): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
- case2(hybrid space, torch.Tensor): :math:`(B, N1)`, where N1 is action_type_shape
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size.
- action_args (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is action_args_shape
(action_args are continuous real value)
Examples:
>>> # Regression mode
>>> model = QAC(64, 64, 'regression')
......@@ -187,9 +403,13 @@ class MAQAC(nn.Module):
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
"""
action_mask = inputs['obs']['action_mask']
x = self.actor(inputs['obs']['agent_state'])
return {'logit': x['logit'], 'action_mask': action_mask}
inputs = inputs['agent_state']
if self.actor_head_type == 'regression':
x = self.actor(inputs)
return {'action': x['pred']}
else:
x = self.actor(inputs)
return {'logit': [x['mu'], x['sigma']]}
def compute_critic(self, inputs: Dict) -> Dict:
r"""
......@@ -197,11 +417,17 @@ class MAQAC(nn.Module):
Execute parameter updates with ``'compute_critic'`` mode
Use encoded embedding tensor to predict output.
Arguments:
- ``obs``, ``action`` encoded tensors.
- inputs (:obj:`Dict`): ``obs``, ``action`` and ``logit`` tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Q-value output.
ArgumentsKeys:
- necessary:
- obs (:obj:`torch.Tensor`): 2-dim vector observation
- action (:obj:`Union[torch.Tensor, Dict]`): action from actor
- optional:
- logit (:obj:`torch.Tensor`): discrete action logit
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
......@@ -212,13 +438,16 @@ class MAQAC(nn.Module):
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
>>> model(inputs, mode='compute_critic')['q_value'] # q value
>>> tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
"""
obs, action = inputs['obs']['global_state'], inputs['action']
if len(action.shape) == 1: # (B, ) -> (B, 1)
action = action.unsqueeze(1)
x = torch.cat([obs, action], dim=-1)
if self.twin_critic:
x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
x = [m(x)['pred'] for m in self.critic]
else:
x = self.critic(inputs['obs']['global_state'])['logit']
x = self.critic(x)['pred']
return {'q_value': x}
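The docstring examples above still reference ``QAC``; below is a minimal usage sketch of ``ContinuousMAQAC`` itself. It is illustrative only and not part of the diff: the shapes follow the Ant 2x4 config further down, the import path assumes the ``ding.model.template`` export added above, and the exact trailing dimension of ``q_value`` depends on how ``RegressionHead`` squeezes its output.

import torch
from ding.model.template import ContinuousMAQAC  # export added in this commit

B, A = 4, 2  # batch size, number of agents
model = ContinuousMAQAC(
    agent_obs_shape=54, global_obs_shape=111, action_shape=4,
    actor_head_type='reparameterization', twin_critic=True
)
# actor: per-agent local observations -> (mu, sigma) of the Gaussian policy
actor_out = model({'agent_state': torch.randn(B, A, 54)}, mode='compute_actor')
mu, sigma = actor_out['logit']  # each of shape (B, A, 4)
# critic: global state (repeated per agent) concatenated with the joint action
critic_in = {
    'obs': {'global_state': torch.randn(B, A, 111)},
    'action': torch.randn(B, A, 4),
}
q_value = model(critic_in, mode='compute_critic')['q_value']  # list of two tensors since twin_critic=True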
......@@ -78,6 +78,7 @@ class CQLPolicy(SACPolicy):
# on-policy setting influences the behaviour of buffer.
# Default False in SAC.
on_policy=False,
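# (bool type) multi_agent: Determine whether to use the multi-agent training setting.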
multi_agent=False,
# (bool type) priority: Determine whether to use priority in buffer sample.
# Default False in SAC.
priority=False,
......
......@@ -583,6 +583,7 @@ class SACPolicy(Policy):
# (int) Number of training samples(randomly collected) in replay buffer when training starts.
# Default 10000 in SAC.
random_collect_size=10000,
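# (bool type) multi_agent: Determine whether to use the multi-agent setting, which switches the default model to the multi-agent QAC.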
multi_agent=False,
model=dict(
# (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
# Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one .
......@@ -1042,7 +1043,10 @@ class SACPolicy(Policy):
return {i: d for i, d in zip(data_id, output)}
def default_model(self) -> Tuple[str, List[str]]:
return 'qac', ['ding.model.template.qac']
if self._cfg.multi_agent:
return 'maqac_continuous', ['ding.model.template.maqac']
else:
return 'qac', ['ding.model.template.qac']
def _monitor_vars_learn(self) -> List[str]:
r"""
......
......@@ -267,11 +267,17 @@ def v_1step_td_error(
) -> torch.Tensor:
v, next_v, reward, done, weight = data
if weight is None:
weight = torch.ones_like(reward)
if done is not None:
target_v = gamma * (1 - done) * next_v + reward
weight = torch.ones_like(v)
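# if v carries an extra agent dimension (multi-agent case), reward and done must be unsqueezed before broadcasting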
if len(v.shape) == len(reward.shape):
if done is not None:
target_v = gamma * (1 - done) * next_v + reward
else:
target_v = gamma * next_v + reward
else:
target_v = gamma * next_v + reward
if done is not None:
target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1)
else:
target_v = gamma * next_v + reward.unsqueeze(1)
td_error_per_sample = criterion(v, target_v.detach())
return (td_error_per_sample * weight).mean(), td_error_per_sample
......
......@@ -196,6 +196,26 @@ def test_v_1step_td():
assert isinstance(v.grad, torch.Tensor)
@pytest.mark.unittest
def test_v_1step_multi_agent_td():
batch_size = 5
agent_num = 2
v = torch.randn(batch_size, agent_num).requires_grad_(True)
next_v = torch.randn(batch_size, agent_num)
reward = torch.rand(batch_size)
done = torch.zeros(batch_size)
data = v_1step_td_data(v, next_v, reward, done, None)
loss, td_error_per_sample = v_1step_td_error(data, 0.99)
assert loss.shape == ()
assert v.grad is None
loss.backward()
assert isinstance(v.grad, torch.Tensor)
data = v_1step_td_data(v, next_v, reward, None, None)
loss, td_error_per_sample = v_1step_td_error(data, 0.99)
loss.backward()
assert isinstance(v.grad, torch.Tensor)
@pytest.mark.unittest
def test_v_nstep_td():
batch_size = 5
......
## Multi Agent Mujoco Env
Multi-Agent MuJoCo is a benchmark environment for continuous multi-agent robotic control, based on OpenAI's MuJoCo Gym environments.
The environment is described in the paper [Deep Multi-Agent Reinforcement Learning for Decentralized Continuous Cooperative Control](https://arxiv.org/abs/2003.06709) by Christian Schroeder de Witt, Bei Peng, Pierre-Alexandre Kamienny, Philip Torr, Wendelin Böhmer and Shimon Whiteson, Torr Vision Group and Whiteson Research Lab, University of Oxford, 2020.
You can find more details in the [Multi-Agent Mujoco Environment](https://github.com/schroederdewitt/multiagent_mujoco) repository. A minimal usage sketch follows.
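The sketch below drives the wrapped `MujocoMulti` environment directly. It follows the upstream repository's SMAC-style `MultiAgentEnv` interface (`get_env_info`, `get_obs`, `get_state`, `step` are upstream API names and are not verified against this PR); the `env_args` keys, including `add_agent_id`, mirror this PR's wrapper, and the import path is assumed from the `envs` package added in this commit.

import numpy as np
from dizoo.multiagent_mujoco.envs import MujocoMulti  # assumed import path

env_args = dict(
    scenario="Ant-v2",    # base Gym scenario
    agent_conf="2x4d",    # 2 agents x 4 joints each (diagonal split)
    agent_obsk=2,         # each agent observes joints up to graph distance k=2
    add_agent_id=False,
    episode_limit=1000,
)
env = MujocoMulti(env_args=env_args)
env_info = env.get_env_info()   # n_agents, n_actions, obs/state shapes, episode_limit
env.reset()
obs = env.get_obs()             # list of per-agent local observations
state = env.get_state()         # global state
actions = [np.random.uniform(-1, 1, env_info["n_actions"]) for _ in range(env_info["n_agents"])]
reward, done, info = env.step(actions)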
from easydict import EasyDict
from ding.entry.serial_entry import serial_pipeline
ant_sac_default_config = dict(
exp_name='multi_mujoco_ant_2x4',
env=dict(
scenario='Ant-v2',
agent_conf="2x4d",
agent_obsk=2,
add_agent_id=False,
episode_limit=1000,
collector_env_num=1,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=6000,
),
policy=dict(
cuda=True,
random_collect_size=0,
multi_agent=True,
model=dict(
agent_obs_shape=54,
global_obs_shape=111,
action_shape=4,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=10,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=True,
log_space=True,
),
collect=dict(
n_sample=400,
unroll_len=1,
),
command=dict(),
eval=dict(evaluator=dict(eval_freq=100, )),
other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
),
)
ant_sac_default_config = EasyDict(ant_sac_default_config)
main_config = ant_sac_default_config
ant_sac_default_create_config = dict(
env=dict(
type='mujoco_multi',
import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
)
ant_sac_default_create_config = EasyDict(ant_sac_default_create_config)
create_config = ant_sac_default_create_config
if __name__ == '__main__':
serial_pipeline((main_config, create_config), seed=0)
from .mujoco_multi import MujocoMulti
from .coupled_half_cheetah import CoupledHalfCheetah
from .manyagent_swimmer import ManyAgentSwimmerEnv
from .manyagent_ant import ManyAgentAntEnv
<!-- Cheetah Model
The state space is populated with joints in the order that they are
defined in this file. The actuators also operate on joints.
State-Space (name/joint/parameter):
- rootx slider position (m)
- rootz slider position (m)
- rooty hinge angle (rad)
- bthigh hinge angle (rad)
- bshin hinge angle (rad)
- bfoot hinge angle (rad)
- fthigh hinge angle (rad)
- fshin hinge angle (rad)
- ffoot hinge angle (rad)
- rootx slider velocity (m/s)
- rootz slider velocity (m/s)
- rooty hinge angular velocity (rad/s)
- bthigh hinge angular velocity (rad/s)
- bshin hinge angular velocity (rad/s)
- bfoot hinge angular velocity (rad/s)
- fthigh hinge angular velocity (rad/s)
- fshin hinge angular velocity (rad/s)
- ffoot hinge angular velocity (rad/s)
Actuators (name/actuator/parameter):
- bthigh hinge torque (N m)
- bshin hinge torque (N m)
- bfoot hinge torque (N m)
- fthigh hinge torque (N m)
- fshin hinge torque (N m)
- ffoot hinge torque (N m)
-->
<mujoco model="cheetah">
<compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
<default>
<joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
<geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
<motor ctrllimited="true" ctrlrange="-1 1"/>
</default>
<size nstack="300000" nuser_geom="1"/>
<option gravity="0 0 -9.81" timestep="0.01"/>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
<body name="torso" pos="0 -1 .7">
<site name="t1" pos="0.0 0 0" size="0.1"/>
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
<joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
<joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
<geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
<geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
<!-- <site name='tip' pos='.15 0 .11'/>-->
<body name="bthigh" pos="-.5 0 0">
<joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
<geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
<body name="bshin" pos=".16 0 -.25">
<joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
<geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
<body name="bfoot" pos="-.28 0 -.14">
<joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
<geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
</body>
</body>
</body>
<body name="fthigh" pos=".5 0 0">
<joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
<geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
<body name="fshin" pos="-.14 0 -.24">
<joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
<geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
<body name="ffoot" pos=".13 0 -.18">
<joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
<geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
</body>
</body>
</body>
</body>
<!-- second cheetah definition -->
<body name="torso2" pos="0 1 .7">
<site name="t2" pos="0 0 0" size="0.1"/>
<camera name="track2" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx2" pos="0 0 0" stiffness="0" type="slide"/>
<joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz2" pos="0 0 0" stiffness="0" type="slide"/>
<joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty2" pos="0 0 0" stiffness="0" type="hinge"/>
<geom fromto="-.5 0 0 .5 0 0" name="torso2" size="0.046" type="capsule"/>
<geom axisangle="0 1 0 .87" name="head2" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
<!-- <site name='tip' pos='.15 0 .11'/>-->
<body name="bthigh2" pos="-.5 0 0">
<joint axis="0 1 0" damping="6" name="bthigh2" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
<geom axisangle="0 1 0 -3.8" name="bthigh2" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
<body name="bshin2" pos=".16 0 -.25">
<joint axis="0 1 0" damping="4.5" name="bshin2" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
<geom axisangle="0 1 0 -2.03" name="bshin2" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
<body name="bfoot2" pos="-.28 0 -.14">
<joint axis="0 1 0" damping="3" name="bfoot2" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
<geom axisangle="0 1 0 -.27" name="bfoot2" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
</body>
</body>
</body>
<body name="fthigh2" pos=".5 0 0">
<joint axis="0 1 0" damping="4.5" name="fthigh2" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
<geom axisangle="0 1 0 .52" name="fthigh2" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
<body name="fshin2" pos="-.14 0 -.24">
<joint axis="0 1 0" damping="3" name="fshin2" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
<geom axisangle="0 1 0 -.6" name="fshin2" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
<body name="ffoot2" pos=".13 0 -.18">
<joint axis="0 1 0" damping="1.5" name="ffoot2" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
<geom axisangle="0 1 0 -.6" name="ffoot2" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
</body>
</body>
</body>
</body>
</worldbody>
<tendon>
<spatial name="tendon1" width="0.05" rgba=".95 .3 .3 1" limited="true" range="1.5 3.5" stiffness="0.1">
<site site="t1"/>
<site site="t2"/>
</spatial>
</tendon>
<actuator>
<motor gear="120" joint="bthigh" name="bthigh"/>
<motor gear="90" joint="bshin" name="bshin"/>
<motor gear="60" joint="bfoot" name="bfoot"/>
<motor gear="120" joint="fthigh" name="fthigh"/>
<motor gear="60" joint="fshin" name="fshin"/>
<motor gear="30" joint="ffoot" name="ffoot"/>
<motor gear="120" joint="bthigh2" name="bthigh2"/>
<motor gear="90" joint="bshin2" name="bshin2"/>
<motor gear="60" joint="bfoot2" name="bfoot2"/>
<motor gear="120" joint="fthigh2" name="fthigh2"/>
<motor gear="60" joint="fshin2" name="fshin2"/>
<motor gear="30" joint="ffoot2" name="ffoot2"/>
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="ant">
<size nconmax="200"/>
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option integrator="RK4" timestep="0.01"/>
<custom>
<numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
</custom>
<default>
<joint armature="1" damping="1" limited="true"/>
<geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
<body name="torso" pos="0 0 0.75">
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<!--<geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>-->
<joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
<body name="front_left_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
<body name="aux_1" pos="0.2 0.2 0">
<joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
<body pos="0.2 0.2 0">
<joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="right_back_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
<body name="aux_4" pos="0.2 -0.2 0">
<joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
<body pos="0.2 -0.2 0">
<joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="midx" pos="0.0 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<!--<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>-->
<body name="front_right_legx" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux_2_geomx" size="0.08" type="capsule"/>
<body name="aux_2x" pos="0.0 0.2 0">
<joint axis="0 0 1" name="hip_2x" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geomx" size="0.08" type="capsule"/>
<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle_2x" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geomx" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="back_legx" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux_3_geomx" size="0.08" type="capsule"/>
<body name="aux_3x" pos="0.0 -0.2 0">
<joint axis="0 0 1" name="hip_3x" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geomx" size="0.08" type="capsule"/>
<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle_3x" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geomx" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="mid" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<!--<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>-->
<!--<body name="front_right_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
<body name="aux_2" pos="-0.2 0.2 0">
<joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="back_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
<body name="aux_3" pos="-0.2 -0.2 0">
<joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>-->
<body name="front_right_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
<body name="aux_2" pos="0.0 0.2 0">
<joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="back_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
<body name="aux_3" pos="0.0 -0.2 0">
<joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="ant">
<size nconmax="200"/>
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option integrator="RK4" timestep="0.005"/>
<custom>
<numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
</custom>
<default>
<joint armature="1" damping="1" limited="true"/>
<geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
<body name="torso_0" pos="0 0 0.75">
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<!--<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>-->
<joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
<body name="front_left_leg_0" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux1_geom_0" size="0.08" type="capsule"/>
<body name="aux1_0" pos="0.2 0.2 0">
<joint axis="0 0 1" name="hip1_0" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom_0" size="0.08" type="capsule"/>
<body pos="0.2 0.2 0">
<joint axis="-1 1 0" name="ankle1_0" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom_0" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="right_back_leg_0" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux2_geom_0" size="0.08" type="capsule"/>
<body name="aux2_0" pos="0.2 -0.2 0">
<joint axis="0 0 1" name="hip2_0" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom_0" size="0.08" type="capsule"/>
<body pos="0.2 -0.2 0">
<joint axis="1 1 0" name="ankle2_0" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="second_ankle_geom_0" size="0.08" type="capsule"/>
</body>
</body>
</body>
{{ body }}
</body>
</worldbody>
<actuator>
{{ actuators }}
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="ant">
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option integrator="RK4" timestep="0.01"/>
<custom>
<numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
</custom>
<default>
<joint armature="1" damping="1" limited="true"/>
<geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
<body name="torso" pos=" 0 0.75">
<camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
<!--<geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>-->
<joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
<body name="front_left_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
<body name="aux_1" pos="0.2 0.2 0">
<joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
<body pos="0.2 0.2 0">
<joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="right_back_leg" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
<body name="aux_4" pos="0.2 -0.2 0">
<joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
<body pos="0.2 -0.2 0">
<joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
<geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="mid" pos="0.0 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="front_right_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
<body name="aux_2" pos="-0.2 0.2 0">
<joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="back_leg" pos="-1 0 0">
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
<body name="aux_3" pos="-0.2 -0.2 0">
<joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
</body>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="swimmer">
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option collision="predefined" density="4000" integrator="RK4" timestep="0.005" viscosity="0.1"/>
<default>
<geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
<joint armature='0.1' />
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
<!-- ================= SWIMMER ================= /-->
<body name="torso" pos="0 0 0">
<geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
<joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
<joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
<joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
<body name="mid0" pos="0.5 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
{{ body }}
</body>
</body>
</worldbody>
<actuator>
{{ actuators }}
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="swimmer">
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option collision="predefined" density="4000" integrator="RK4" timestep="0.01" viscosity="0.1"/>
<default>
<geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
<joint armature='0.1' />
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
<!-- ================= SWIMMER ================= /-->
<body name="torso" pos="0 0 0">
<geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
<joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
<joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
<joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
<body name="mid1" pos="0.5 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="mid2" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 -1" limited="true" name="rot1" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="mid3" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="back" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot3" pos="0 0 0" range="-100 100" type="hinge"/>
</body>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot0"/>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot1"/>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot2"/>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot3"/>
</actuator>
</mujoco>
\ No newline at end of file
<mujoco model="swimmer">
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
<option collision="predefined" density="4000" integrator="RK4" timestep="0.01" viscosity="0.1"/>
<default>
<geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
<joint armature='0.1' />
</default>
<asset>
<texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
<material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
<!-- ================= SWIMMER ================= /-->
<body name="torso" pos="0 0 0">
<geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
<joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
<joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
<joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
<body name="mid1" pos="0.5 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="mid2" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 -1" limited="true" name="rot1" pos="0 0 0" range="-100 100" type="hinge"/>
<body name="back" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
</body>
</body>
</body>
</body>
</worldbody>
<actuator>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot0"/>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot1"/>
<motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot2"/>
</actuator>
</mujoco>
\ No newline at end of file
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
import os
class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs):
mujoco_env.MujocoEnv.__init__(
self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5
)
utils.EzPickle.__init__(self)
def step(self, action):
xposbefore1 = self.sim.data.qpos[0]
xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
self.do_simulation(action, self.frame_skip)
xposafter1 = self.sim.data.qpos[0]
xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
ob = self._get_obs()
reward_ctrl1 = -0.1 * np.square(action[0:len(action) // 2]).sum()
reward_ctrl2 = -0.1 * np.square(action[len(action) // 2:]).sum()
reward_run1 = (xposafter1 - xposbefore1) / self.dt
reward_run2 = (xposafter2 - xposbefore2) / self.dt
reward = (reward_ctrl1 + reward_ctrl2) / 2.0 + (reward_run1 + reward_run2) / 2.0
done = False
return ob, reward, done, dict(
reward_run1=reward_run1, reward_ctrl1=reward_ctrl1, reward_run2=reward_run2, reward_ctrl2=reward_ctrl2
)
def _get_obs(self):
return np.concatenate([
self.sim.data.qpos.flat[1:],
self.sim.data.qvel.flat,
])
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
self.set_state(qpos, qvel)
return self._get_obs()
def viewer_setup(self):
self.viewer.cam.distance = self.model.stat.extent * 0.5
def get_env_info(self):
return {"episode_limit": self.episode_limit}
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from jinja2 import Template
import os
class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs):
agent_conf = kwargs.get("agent_conf")
n_agents = int(agent_conf.split("x")[0])
n_segs_per_agents = int(agent_conf.split("x")[1])
n_segs = n_agents * n_segs_per_agents
# Check whether asset file exists already, otherwise create it
asset_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'assets',
'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
)
#if not os.path.exists(asset_path):
print("Auto-Generating Manyagent Ant asset with {} segments at {}.".format(n_segs, asset_path))
self._generate_asset(n_segs=n_segs, asset_path=asset_path)
#asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
# 'manyagent_swimmer.xml')
mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
utils.EzPickle.__init__(self)
def _generate_asset(self, n_segs, asset_path):
template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_ant.xml.template')
with open(template_path, "r") as f:
t = Template(f.read())
body_str_template = """
<body name="torso_{:d}" pos="-1 0 0">
<!--<joint axis="0 1 0" name="nnn_{:d}" pos="0.0 0.0 0.0" range="-1 1" type="hinge"/>-->
<geom density="100" fromto="1 0 0 0 0 0" size="0.1" type="capsule"/>
<body name="front_right_leg_{:d}" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux1_geom_{:d}" size="0.08" type="capsule"/>
<body name="aux_2_{:d}" pos="0.0 0.2 0">
<joint axis="0 0 1" name="hip1_{:d}" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom_{:d}" size="0.08" type="capsule"/>
<body pos="-0.2 0.2 0">
<joint axis="1 1 0" name="ankle1_{:d}" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom_{:d}" size="0.08" type="capsule"/>
</body>
</body>
</body>
<body name="back_leg_{:d}" pos="0 0 0">
<geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux2_geom_{:d}" size="0.08" type="capsule"/>
<body name="aux2_{:d}" pos="0.0 -0.2 0">
<joint axis="0 0 1" name="hip2_{:d}" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom_{:d}" size="0.08" type="capsule"/>
<body pos="-0.2 -0.2 0">
<joint axis="-1 1 0" name="ankle2_{:d}" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
<geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom_{:d}" size="0.08" type="capsule"/>
</body>
</body>
</body>
"""
body_close_str_template = "</body>\n"
actuator_str_template = """\t <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip1_{:d}" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle1_{:d}" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip2_{:d}" gear="150"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle2_{:d}" gear="150"/>\n"""
body_str = ""
for i in range(1, n_segs):
body_str += body_str_template.format(*([i] * 16))
body_str += body_close_str_template * (n_segs - 1)
actuator_str = ""
for i in range(n_segs):
actuator_str += actuator_str_template.format(*([i] * 8))
rt = t.render(body=body_str, actuators=actuator_str)
with open(asset_path, "w") as f:
f.write(rt)
pass
def step(self, a):
xposbefore = self.get_body_com("torso_0")[0]
self.do_simulation(a, self.frame_skip)
xposafter = self.get_body_com("torso_0")[0]
forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = .5 * np.square(a).sum()
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward
state = self.state_vector()
notdone = np.isfinite(state).all() \
and state[2] >= 0.2 and state[2] <= 1.0
done = not notdone
ob = self._get_obs()
return ob, reward, done, dict(
reward_forward=forward_reward,
reward_ctrl=-ctrl_cost,
reward_contact=-contact_cost,
reward_survive=survive_reward
)
def _get_obs(self):
return np.concatenate(
[
self.sim.data.qpos.flat[2:],
self.sim.data.qvel.flat,
np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
]
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
self.set_state(qpos, qvel)
return self._get_obs()
def viewer_setup(self):
self.viewer.cam.distance = self.model.stat.extent * 0.5
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
import os
from jinja2 import Template
class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs):
agent_conf = kwargs.get("agent_conf")
n_agents = int(agent_conf.split("x")[0])
n_segs_per_agents = int(agent_conf.split("x")[1])
n_segs = n_agents * n_segs_per_agents
# Check whether asset file exists already, otherwise create it
asset_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'assets',
'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
)
# if not os.path.exists(asset_path):
print("Auto-Generating Manyagent Swimmer asset with {} segments at {}.".format(n_segs, asset_path))
self._generate_asset(n_segs=n_segs, asset_path=asset_path)
#asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
# 'manyagent_swimmer.xml')
mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
utils.EzPickle.__init__(self)
def _generate_asset(self, n_segs, asset_path):
template_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_swimmer.xml.template'
)
with open(template_path, "r") as f:
t = Template(f.read())
body_str_template = """
<body name="mid{:d}" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 {:d}" limited="true" name="rot{:d}" pos="0 0 0" range="-100 100" type="hinge"/>
"""
body_end_str_template = """
<body name="back" pos="-1 0 0">
<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
<joint axis="0 0 1" limited="true" name="rot{:d}" pos="0 0 0" range="-100 100" type="hinge"/>
</body>
"""
body_close_str_template = "</body>\n"
actuator_str_template = """\t <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot{:d}"/>\n"""
body_str = ""
for i in range(1, n_segs - 1):
body_str += body_str_template.format(i, (-1) ** (i + 1), i)
body_str += body_end_str_template.format(n_segs - 1)
body_str += body_close_str_template * (n_segs - 2)
actuator_str = ""
for i in range(n_segs):
actuator_str += actuator_str_template.format(i)
rt = t.render(body=body_str, actuators=actuator_str)
with open(asset_path, "w") as f:
f.write(rt)
pass
def step(self, a):
ctrl_cost_coeff = 0.0001
xposbefore = self.sim.data.qpos[0]
self.do_simulation(a, self.frame_skip)
xposafter = self.sim.data.qpos[0]
reward_fwd = (xposafter - xposbefore) / self.dt
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
reward = reward_fwd + reward_ctrl
ob = self._get_obs()
return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
def _get_obs(self):
qpos = self.sim.data.qpos
qvel = self.sim.data.qvel
return np.concatenate([qpos.flat[2:], qvel.flat])
def reset_model(self):
self.set_state(
self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
)
return self._get_obs()
from functools import partial
import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
import numpy as np
from .multiagentenv import MultiAgentEnv
from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs
# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
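# NormalizedActions rescales a policy action from [-1, 1] into the wrapped env's native
# [low, high] range (and _reverse_action maps back), so agents always emit actions in [-1, 1].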
class NormalizedActions(gym.ActionWrapper):
def _action(self, action):
action = (action + 1) / 2
action *= (self.action_space.high - self.action_space.low)
action += self.action_space.low
return action
def action(self, action_):
return self._action(action_)
def _reverse_action(self, action):
action -= self.action_space.low
action /= (self.action_space.high - self.action_space.low)
action = action * 2 - 1
return action
class MujocoMulti(MultiAgentEnv):
def __init__(self, batch_size=None, **kwargs):
super().__init__(batch_size, **kwargs)
self.add_agent_id = kwargs["env_args"]["add_agent_id"]
self.scenario = kwargs["env_args"]["scenario"] # e.g. Ant-v2
self.agent_conf = kwargs["env_args"]["agent_conf"] # e.g. '2x3'
self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(
self.scenario, self.agent_conf
)
self.n_agents = len(self.agent_partitions)
self.n_actions = max([len(l) for l in self.agent_partitions])
self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False)
self.agent_obsk = kwargs["env_args"].get(
"agent_obsk", None
) # if None, fully observable else k>=0 implies observe nearest k agents or joints
self.agent_obsk_agents = kwargs["env_args"].get(
"agent_obsk_agents", False
) # observe full k nearest agents (True) or just single joints (False)
if self.agent_obsk is not None:
self.k_categories_label = kwargs["env_args"].get("k_categories")
if self.k_categories_label is None:
if self.scenario in ["Ant-v2", "manyagent_ant"]:
self.k_categories_label = "qpos,qvel,cfrc_ext|qpos"
elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
elif self.scenario in ["Reacher-v2"]:
self.k_categories_label = "qpos,qvel,fingertip_dist|qpos"
elif self.scenario in ["coupled_half_cheetah"]:
self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|"
else:
self.k_categories_label = "qpos,qvel|qpos"
k_split = self.k_categories_label.split("|")
self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)]
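# k_categories[d] lists the MuJoCo fields observed for joints at graph distance d from the agent;
# if fewer "|"-separated groups than distances are given, the last group is reused
# (e.g. "qpos,qvel|qpos": an agent's own joints expose qpos and qvel, farther joints only qpos).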
self.global_categories_label = kwargs["env_args"].get("global_categories")
self.global_categories = self.global_categories_label.split(
","
) if self.global_categories_label is not None else []
if self.agent_obsk is not None:
self.k_dicts = [
get_joints_at_kdist(
agent_id,
self.agent_partitions,
self.mujoco_edges,
k=self.agent_obsk,
kagents=False,
) for agent_id in range(self.n_agents)
]
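# k_dicts[i] maps each distance d <= agent_obsk to the joints reachable from agent i within d
# hops of the kinematic graph; build_obs() later gathers the configured fields from them.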
# load scenario from script
self.episode_limit = self.args.episode_limit
self.env_version = kwargs["env_args"].get("env_version", 2)
if self.env_version == 2:
try:
self.wrapped_env = NormalizedActions(gym.make(self.scenario))
except gym.error.Error: # env not in gym
if self.scenario in ["manyagent_ant"]:
from .manyagent_ant import ManyAgentAntEnv as this_env
elif self.scenario in ["manyagent_swimmer"]:
from .manyagent_swimmer import ManyAgentSwimmerEnv as this_env
elif self.scenario in ["coupled_half_cheetah"]:
from .coupled_half_cheetah import CoupledHalfCheetah as this_env
else:
raise NotImplementedError('Custom env not implemented!')
self.wrapped_env = NormalizedActions(
TimeLimit(this_env(**kwargs["env_args"]), max_episode_steps=self.episode_limit)
)
else:
assert False, "not implemented!"
self.timelimit_env = self.wrapped_env.env
self.timelimit_env._max_episode_steps = self.episode_limit
self.env = self.timelimit_env.env
self.timelimit_env.reset()
self.obs_size = self.get_obs_size()
# COMPATIBILITY
self.n = self.n_agents
self.observation_space = [
Box(low=np.array([-10] * self.n_agents), high=np.array([10] * self.n_agents)) for _ in range(self.n_agents)
]
acdims = [len(ap) for ap in self.agent_partitions]
self.action_space = tuple(
[
Box(
self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]
) for a in range(self.n_agents)
]
)
pass
def step(self, actions):
# need to remove dummy actions that arise due to unequal action vector sizes across agents
flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
self.steps += 1
info = {}
info.update(info_n)
if done_n:
if self.steps < self.episode_limit:
info["episode_limit"] = False # the next state will be masked out
else:
info["episode_limit"] = True # the next state will not be masked out
obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
return obs, reward_n, done_n, info
def get_obs(self):
""" Returns all agent observat3ions in a list """
obs_n = []
for a in range(self.n_agents):
obs_n.append(self.get_obs_agent(a))
return np.array(obs_n).astype(np.float32)
def get_obs_agent(self, agent_id):
if self.agent_obsk is None:
return self.env._get_obs()
else:
return build_obs(
self.env,
self.k_dicts[agent_id],
self.k_categories,
self.mujoco_globals,
self.global_categories,
vec_len=getattr(self, "obs_size", None)
)
def get_obs_size(self):
""" Returns the shape of the observation """
if self.agent_obsk is None:
return self.get_obs_agent(0).size
else:
return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])
def get_state(self, team=None):
# TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)
state_n = []
if self.add_agent_id:
state = self.env._get_obs()
for a in range(self.n_agents):
agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
agent_id_feats[a] = 1.0
state_i = np.concatenate([state, agent_id_feats])
state_n.append(state_i)
else:
for a in range(self.n_agents):
state_n.append(self.env._get_obs())
return np.array(state_n).astype(np.float32)
def get_state_size(self):
""" Returns the shape of the state"""
return len(self.get_state())
def get_avail_actions(self): # all actions are always available
return np.ones(shape=(
self.n_agents,
self.n_actions,
))
def get_avail_agent_actions(self, agent_id):
""" Returns the available actions for agent_id """
return np.ones(shape=(self.n_actions, ))
def get_total_actions(self):
""" Returns the total number of actions an agent could ever take """
return self.n_actions  # CAREFUL! - for continuous actions this is the per-agent action dimension, not a count of discrete actions
# return self.env.action_space.shape[0]
def get_stats(self):
return {}
# TODO: Temp hack
def get_agg_stats(self, stats):
return {}
def reset(self, **kwargs):
""" Returns initial observations and states"""
self.steps = 0
self.timelimit_env.reset()
obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
return obs
def render(self, **kwargs):
self.env.render(**kwargs)
def close(self):
pass
#raise NotImplementedError
def seed(self, args):
pass
def get_env_info(self):
env_info = {
"state_shape": self.get_state_size(),
"obs_shape": self.get_obs_size(),
"n_actions": self.get_total_actions(),
"n_agents": self.n_agents,
"episode_limit": self.episode_limit,
"action_spaces": self.action_space,
"actions_dtype": np.float32,
"normalise_actions": False
}
return env_info
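# Usage sketch (illustrative; the keys mirror what __init__ above reads, the values are
# assumptions): a 2-agent Ant where each agent controls 4 joints and observes its 1-hop
# neighbourhood. Requires gym + mujoco-py and the obsk partitions for the chosen scenario.
if __name__ == "__main__":
    demo_args = dict(
        scenario="Ant-v2",  # standard gym id, wrapped by NormalizedActions above
        agent_conf="2x4",   # n_agents x joints-per-agent
        agent_obsk=1,       # None would mean fully observable
        add_agent_id=True,
        episode_limit=1000,
    )
    demo_env = MujocoMulti(env_args=demo_args)
    demo_obs = demo_env.reset()
    demo_actions = np.array([space.sample() for space in demo_env.action_space])
    demo_obs, demo_reward, demo_done, demo_info = demo_env.step(demo_actions)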
from typing import Any, Union, List
import copy
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray, to_list
from .mujoco_multi import MujocoMulti
from ding.utils import ENV_REGISTRY
@ENV_REGISTRY.register('mujoco_multi')
class MujocoEnv(BaseEnv):
def __init__(self, cfg: dict) -> None:
self._cfg = cfg
self._init_flag = False
def reset(self) -> np.ndarray:
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._cfg.seed = self._seed + np_seed
elif hasattr(self, '_seed'):
self._cfg.seed = self._seed
if not self._init_flag:
self._env = MujocoMulti(env_args=self._cfg)
self._init_flag = True
obs = self._env.reset()
#print(obs)
#obs['agent_state'] = to_ndarray(obs['agent_state']).astype('float32')
#obs['global_state'] = to_ndarray(obs['global_state']).astype('float32')
self._final_eval_reward = 0.
return obs
def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
action = to_ndarray(action)
obs, rew, done, info = self._env.step(action)
self._final_eval_reward += rew
#obs = to_ndarray(obs).astype('float32')
rew = to_ndarray([rew])  # wrapped to be transferred to an array with shape (1,)
if done:
info['final_eval_reward'] = self._final_eval_reward
return BaseEnvTimestep(obs, rew, done, info)
def info(self) -> BaseEnvInfo:
env_info = self._env.get_env_info()
info = BaseEnvInfo(
agent_num=env_info['n_agents'],
obs_space=EnvElementInfo(
shape={
'agent_state': env_info['obs_shape'],
'global_state': env_info['state_shape'],
},
value={
'min': np.float32("-inf"),
'max': np.float32("inf"),
'dtype': np.float32
},
),
act_space=EnvElementInfo(
shape=env_info['action_spaces'],
value={
'min': np.float32("-inf"),
'max': np.float32("inf"),
'dtype': np.float32
},
),
rew_space=EnvElementInfo(
shape=1,
value={
'min': np.float64("-inf"),
'max': np.float64("inf")
},
),
use_wrappers=None,
)
return info
def __repr__(self) -> str:
return "DI-engine Multi-agent Mujoco Env({})".format(self._cfg.env_id)
from collections import namedtuple
import numpy as np
def convert(dictionary):
return namedtuple('GenericDict', dictionary.keys())(**dictionary)
class MultiAgentEnv(object):
def __init__(self, batch_size=None, **kwargs):
# Unpack arguments from sacred
args = kwargs["env_args"]
if isinstance(args, dict):
args = convert(args)
self.args = args
if getattr(args, "seed", None) is not None:
self.seed = args.seed
self.rs = np.random.RandomState(self.seed) # initialise numpy random state
def step(self, actions):
""" Returns reward, terminated, info """
raise NotImplementedError
def get_obs(self):
""" Returns all agent observations in a list """
raise NotImplementedError
def get_obs_agent(self, agent_id):
""" Returns observation for agent_id """
raise NotImplementedError
def get_obs_size(self):
""" Returns the shape of the observation """
raise NotImplementedError
def get_state(self):
raise NotImplementedError
def get_state_size(self):
""" Returns the shape of the state"""
raise NotImplementedError
def get_avail_actions(self):
raise NotImplementedError
def get_avail_agent_actions(self, agent_id):
""" Returns the available actions for agent_id """
raise NotImplementedError
def get_total_actions(self):
""" Returns the total number of actions an agent could ever take """
# TODO: This is only suitable for a discrete 1 dimensional action space for each agent
raise NotImplementedError
def get_stats(self):
raise NotImplementedError
# TODO: Temp hack
def get_agg_stats(self, stats):
return {}
def reset(self):
""" Returns initial observations and states"""
raise NotImplementedError
def render(self):
raise NotImplementedError
def close(self):
raise NotImplementedError
def seed(self, seed):
raise NotImplementedError
def get_env_info(self):
env_info = {
"state_shape": self.get_state_size(),
"obs_shape": self.get_obs_size(),
"n_actions": self.get_total_actions(),
"n_agents": self.n_agents,
"episode_limit": self.episode_limit
}
return env_info