From f1bf66d056dacbf49003ae8fbcdc2bac87b4c34d Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 7 Sep 2021 16:21:19 +0800 Subject: [PATCH] feature(wyh): add mappo algorithm for SMAC --- ding/entry/__init__.py | 3 +- ding/entry/cli.py | 9 +- ding/entry/serial_entry_onpolicy.py | 91 ++++++ ding/model/template/__init__.py | 1 + ding/model/template/mappo.py | 251 ++++++++++++++++ ding/policy/ppo.py | 10 +- ding/rl_utils/adder.py | 2 +- ding/rl_utils/gae.py | 8 +- ding/rl_utils/ppo.py | 8 +- ding/rl_utils/tests/test_adder.py | 46 +++ ding/rl_utils/tests/test_gae.py | 11 + ding/utils/data/collate_fn.py | 5 +- ding/utils/default_helper.py | 32 +- dizoo/smac/config/smac_3s5z_mappo_config.py | 91 ++++++ dizoo/smac/config/smac_5m6m_mappo_config.py | 90 ++++++ dizoo/smac/config/smac_MMM2_mappo_config.py | 89 ++++++ dizoo/smac/config/smac_MMM_mappo_config.py | 90 ++++++ dizoo/smac/envs/smac_env.py | 312 ++++++++++++++++++-- 18 files changed, 1112 insertions(+), 37 deletions(-) create mode 100644 ding/entry/serial_entry_onpolicy.py create mode 100644 ding/model/template/mappo.py create mode 100644 dizoo/smac/config/smac_3s5z_mappo_config.py create mode 100644 dizoo/smac/config/smac_5m6m_mappo_config.py create mode 100644 dizoo/smac/config/smac_MMM2_mappo_config.py create mode 100644 dizoo/smac/config/smac_MMM_mappo_config.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 9b27566..53fb990 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -1,7 +1,8 @@ from .cli import cli from .serial_entry import serial_pipeline +from .serial_entry_onpolicy import serial_pipeline_onpolicy +from .serial_entry_offline import serial_pipeline_offline from .serial_entry_il import serial_pipeline_il from .serial_entry_reward_model import serial_pipeline_reward_model from .parallel_entry import parallel_pipeline from .application_entry import eval, collect_demo_data -from .serial_entry_offline import serial_pipeline_offline diff --git a/ding/entry/cli.py b/ding/entry/cli.py index b47f209..cc9df3e 100644 --- a/ding/entry/cli.py +++ b/ding/entry/cli.py @@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) @click.option( '-m', '--mode', - type=click.Choice(['serial', 'serial_sqil', 'parallel', 'dist', 'eval']), + type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'parallel', 'dist', 'eval']), help='serial-train or parallel-train or dist-train or eval' ) @click.option('-c', '--config', type=str, help='Path to DRL experiment config') @@ -144,7 +144,12 @@ def cli( if config is None: config = get_predefined_config(env, policy) serial_pipeline(config, seed, max_iterations=train_iter) - if mode == 'serial_sqil': + elif mode == 'serial_onpolicy': + from .serial_entry_onpolicy import serial_pipeline_onpolicy + if config is None: + config = get_predefined_config(env, policy) + serial_pipeline_onpolicy(config, seed, max_iterations=train_iter) + elif mode == 'serial_sqil': if config == 'lunarlander_sqil_config.py' or 'cartpole_sqil_config.py' or 'pong_sqil_config.py' \ or 'spaceinvaders_sqil_config.py' or 'qbert_sqil_config.py': from .serial_entry_sqil import serial_pipeline_sqil diff --git a/ding/entry/serial_entry_onpolicy.py b/ding/entry/serial_entry_onpolicy.py new file mode 100644 index 0000000..3409697 --- /dev/null +++ b/ding/entry/serial_entry_onpolicy.py @@ -0,0 +1,91 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +import logging +from functools import partial +from tensorboardX import SummaryWriter + +from ding.envs import 
get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \ + create_serial_collector +from ding.config import read_config, compile_config +from ding.policy import create_policy, PolicyFactory +from ding.utils import set_pkg_seed + + +def serial_pipeline_onpolicy( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_iterations: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + Serial pipeline entry for onpolicy algorithm(such as PPO). + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_iterations (:obj:`Optional[torch.nn.Module]`): Learner's max iteration. Pipeline will stop \ + when reaching this iteration. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = input_cfg + create_cfg.policy.type = create_cfg.policy.type + '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = create_serial_collector( + cfg.policy.collect.collector, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + evaluator = BaseSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook('before_run') + + # Accumulate plenty of data at the beginning of training. + for _ in range(max_iterations): + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Collect data by default config n_sample/n_episode + new_data = collector.collect(train_iter=learner.train_iter) + # Learn policy from collected data + learner.train(new_data, collector.envstep) + + # Learner's after_run hook. 
+ learner.call_hook('after_run') + return policy diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index ebffecf..3027648 100644 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -10,3 +10,4 @@ from .atoc import ATOC from .sqn import SQN from .acer import ACER from .qtran import QTran +from .mappo import MAPPO diff --git a/ding/model/template/mappo.py b/ding/model/template/mappo.py new file mode 100644 index 0000000..154aefd --- /dev/null +++ b/ding/model/template/mappo.py @@ -0,0 +1,251 @@ +from typing import Union, Dict, Optional +import torch +import torch.nn as nn + +from ding.utils import SequenceType, squeeze, MODEL_REGISTRY +from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ + FCEncoder, ConvEncoder + + +@MODEL_REGISTRY.register('mappo') +class MAPPO(nn.Module): + r""" + Overview: + The MAPPO model. + Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + agent_obs_shape: Union[int, SequenceType], + global_obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + agent_num: int, + encoder_hidden_size_list: SequenceType = [128, 128, 64], + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 2, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + ) -> None: + r""" + Overview: + Init the VAC Model according to arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. + - action_shape (:obj:`Union[int, SequenceType]`): Action's space. + - share_encoder (:obj:`bool`): Whether share encoder. + - continuous (:obj:`bool`): Whether collect continuously. + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder`` + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. + - actor_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for actor's nn. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. + - critic_head_layer_num (:obj:`int`): + The num of layers used in the network to compute Q value output for critic's nn. + - activation (:obj:`Optional[nn.Module]`): + The type of activation function to use in ``MLP`` the after ``layer_fn``, + if ``None`` then default set to ``nn.ReLU()`` + - norm_type (:obj:`Optional[str]`): + The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` + """ + super(MAPPO, self).__init__() + agent_obs_shape: int = squeeze(agent_obs_shape) + global_obs_shape: int = squeeze(global_obs_shape) + action_shape: int = squeeze(action_shape) + self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape + # Encoder Type + if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1: + encoder_cls = FCEncoder + elif len(agent_obs_shape) == 3: + encoder_cls = ConvEncoder + else: + raise RuntimeError( + "not support obs_shape for pre-defined encoder: {}, please customize your own DQN". 
+ format(agent_obs_shape) + ) + if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1: + global_encoder_cls = FCEncoder + elif len(global_obs_shape) == 3: + global_encoder_cls = ConvEncoder + else: + raise RuntimeError( + "not support obs_shape for pre-defined encoder: {}, please customize your own DQN". + format(global_obs_shape) + ) + + self.actor_encoder = encoder_cls( + agent_obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type + ) + self.critic_encoder = global_encoder_cls( + global_obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type + ) + # Head Type + self.critic_head = RegressionHead( + critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type + ) + + actor_head_cls = DiscreteHead + self.actor_head = actor_head_cls( + actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type + ) + # must use list, not nn.ModuleList + self.actor = [self.actor_encoder, self.actor_head] + self.critic = [self.critic_encoder, self.critic_head] + # for convenience of call some apis(such as: self.critic.parameters()), but may cause + # misunderstanding when print(self) + self.actor = nn.ModuleList(self.actor) + self.critic = nn.ModuleList(self.critic) + + def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: + r""" + Overview: + Use encoded embedding tensor to predict output. + Parameter updates with VAC's MLPs forward setup. + Arguments: + Forward with ``'compute_actor'`` or ``'compute_critic'``: + - inputs (:obj:`torch.Tensor`): + The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. + Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``. + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + Forward with ``'compute_actor'``, Necessary Keys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + + Forward with ``'compute_critic'``, Necessary Keys: + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponding ``hidden_size`` + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + + Actor Examples: + >>> model = VAC(64,128) + >>> inputs = torch.randn(4, 64) + >>> actor_outputs = model(inputs,'compute_actor') + >>> assert actor_outputs['logit'].shape == torch.Size([4, 128]) + + Critic Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> critic_outputs = model(inputs,'compute_critic') + >>> critic_outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + + Actor-Critic Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> outputs = model(inputs,'compute_actor_critic') + >>> outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + >>> assert outputs['logit'].shape == torch.Size([4, 64]) + + """ + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(inputs) + + def compute_actor(self, x: torch.Tensor) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_actor'`` mode + Use encoded embedding tensor to predict output. 
+ Arguments: + - inputs (:obj:`torch.Tensor`): + The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. + ``hidden_size = actor_head_hidden_size`` + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + ReturnsKeys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + Shapes: + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + + Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> actor_outputs = model(inputs,'compute_actor') + >>> assert actor_outputs['action'].shape == torch.Size([4, 64]) + """ + action_mask = x['action_mask'] + x = x['agent_state'] + + x = self.actor_encoder(x) + x = self.actor_head(x) + logit = x['logit'] + logit[action_mask == 0.0] = -99999999 + return {'logit': logit, 'action_mask': action_mask} + + def compute_critic(self, x: Dict) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_critic'`` mode + Use encoded embedding tensor to predict output. + Arguments: + - inputs (:obj:`Dict`): + The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. + ``hidden_size = critic_head_hidden_size`` + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + Necessary Keys: + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + + Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> critic_outputs = model(inputs,'compute_critic') + >>> critic_outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + """ + + x = self.critic_encoder(x['global_state']) + x = self.critic_head(x) + return {'value': x['pred']} + + def compute_actor_critic(self, x: Dict) -> Dict: + r""" + Overview: + Execute parameter updates with ``'compute_actor_critic'`` mode + Use encoded embedding tensor to predict output. + Arguments: + - inputs (:obj:`torch.Tensor`): The encoded embedding tensor. + + Returns: + - outputs (:obj:`Dict`): + Run with encoder and head. + + ReturnsKeys: + - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + + Examples: + >>> model = VAC(64,64) + >>> inputs = torch.randn(4, 64) + >>> outputs = model(inputs,'compute_actor_critic') + >>> outputs['value'] + tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=) + >>> assert outputs['logit'].shape == torch.Size([4, 64]) + + + .. note:: + ``compute_actor_critic`` interface aims to save computation when shares encoder. + Returning the combination dictionry. 
+ + """ + logit = self.compute_actor(x)['logit'] + value = self.compute_critic(x)['value'] + action_mask = x['action_mask'] + return {'logit': logit, 'value': value, 'action_mask': x['action_mask']} diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 9d8620f..bbbd970 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -35,6 +35,7 @@ class PPOPolicy(Policy): priority_IS_weight=False, recompute_adv=True, continuous=True, + multi_agent=False, learn=dict( # (bool) Whether to use multi gpu multi_gpu=False, @@ -124,7 +125,7 @@ class PPOPolicy(Policy): self._adv_norm = self._cfg.learn.adv_norm self._value_norm = self._cfg.learn.value_norm if self._value_norm: - self._running_mean_std = RunningMeanStd(epsilon=1e-4) + self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device) self._gamma = self._cfg.collect.discount_factor self._gae_lambda = self._cfg.collect.gae_lambda self._recompute_adv = self._cfg.recompute_adv @@ -321,7 +322,7 @@ class PPOPolicy(Policy): data[-1]['done'] = False if data[-1]['done']: - last_value = torch.zeros(1) + last_value = torch.zeros_like(data[-1]['value']) else: with torch.no_grad(): last_value = self._collect_model.forward( @@ -382,7 +383,10 @@ class PPOPolicy(Policy): return {i: d for i, d in zip(data_id, output)} def default_model(self) -> Tuple[str, List[str]]: - return 'vac', ['ding.model.template.vac'] + if self._cfg.multi_agent: + return 'mappo', ['ding.model.template.mappo'] + else: + return 'vac', ['ding.model.template.vac'] def _monitor_vars_learn(self) -> List[str]: variables = super()._monitor_vars_learn() + [ diff --git a/ding/rl_utils/adder.py b/ding/rl_utils/adder.py index e846130..dc1bd64 100644 --- a/ding/rl_utils/adder.py +++ b/ding/rl_utils/adder.py @@ -65,7 +65,7 @@ class Adder(object): extra advantage key 'adv' """ if done: - last_value = torch.zeros(1) + last_value = torch.zeros_like(data[-1]['value']) else: last_data = data.pop() last_value = last_data['value'] diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index f597249..0ba4e6e 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -46,11 +46,13 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F value, next_value, reward, done = data if done is None: done = torch.zeros_like(reward, device=reward.device) - + if len(value.shape) == len(reward.shape) + 1: # for some marl case: value(T, B, A), reward(T, B) + reward = reward.unsqueeze(-1) + done = done.unsqueeze(-1) delta = reward + (1 - done) * gamma * next_value - value factor = gamma * lambda_ - adv = torch.zeros_like(reward, device=reward.device) - gae_item = 0. 
+ adv = torch.zeros_like(value, device=value.device) + gae_item = torch.zeros_like(value[0]) for t in reversed(range(reward.shape[0])): gae_item = delta[t] + factor * gae_item * (1 - done[t]) diff --git a/ding/rl_utils/ppo.py b/ding/rl_utils/ppo.py index ae162ba..e51851a 100644 --- a/ding/rl_utils/ppo.py +++ b/ding/rl_utils/ppo.py @@ -92,9 +92,14 @@ def ppo_policy_error(data: namedtuple, dist_old = torch.distributions.categorical.Categorical(logits=logit_old) logp_new = dist_new.log_prob(action) logp_old = dist_old.log_prob(action) - entropy_loss = (dist_new.entropy() * weight).mean() + dist_new_entropy = dist_new.entropy() + if dist_new_entropy.shape != weight.shape: + dist_new_entropy = dist_new.entropy().mean(dim=1) + entropy_loss = (dist_new_entropy * weight).mean() # policy_loss ratio = torch.exp(logp_new - logp_old) + if ratio.shape != adv.shape: + ratio = ratio.mean(dim=1) surr1 = ratio * adv surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv if dual_clip is not None: @@ -103,6 +108,7 @@ def ppo_policy_error(data: namedtuple, # only use dual_clip when adv < 0 policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean() else: + #policy_loss = (-torch.min(surr1, surr2) * weight).mean() policy_loss = (-torch.min(surr1, surr2) * weight).mean() with torch.no_grad(): approx_kl = (logp_old - logp_new).mean().item() diff --git a/ding/rl_utils/tests/test_adder.py b/ding/rl_utils/tests/test_adder.py index b1be649..ab7345f 100644 --- a/ding/rl_utils/tests/test_adder.py +++ b/ding/rl_utils/tests/test_adder.py @@ -18,6 +18,15 @@ class TestAdder: 'done': False } + def get_transition_multi_agent(self): + return { + 'value': torch.randn(1, 8), + 'reward': torch.rand(1, 1), + 'other': np.random.randint(0, 10, size=(4, )), + 'obs': torch.randn(3), + 'done': False + } + def test_get_gae(self): transitions = deque([self.get_transition() for _ in range(10)]) last_value = torch.randn(1) @@ -46,6 +55,39 @@ class TestAdder: for i in range(len(output)): assert output[i]['adv'].eq(output2[i]['adv']) + def test_get_gae_multi_agent(self): + transitions = deque([self.get_transition_multi_agent() for _ in range(10)]) + last_value = torch.randn(1, 8) + output = get_gae(transitions, last_value, gamma=0.99, gae_lambda=0.97, cuda=False) + for i in range(len(output)): + o = output[i] + assert 'adv' in o.keys() + for k, v in o.items(): + if k == 'adv': + assert isinstance(v, torch.Tensor) + assert v.shape == ( + 1, + 8, + ) + else: + if k == 'done': + assert v == transitions[i][k] + else: + assert (v == transitions[i][k]).all() + output1 = get_gae_with_default_last_value( + copy.deepcopy(transitions), True, gamma=0.99, gae_lambda=0.97, cuda=False + ) + for i in range(len(output)): + for j in range(output[i]['adv'].shape[1]): + assert output[i]['adv'][0][j].ne(output1[i]['adv'][0][j]) + + data = copy.deepcopy(transitions) + data.append({'value': last_value}) + output2 = get_gae_with_default_last_value(data, False, gamma=0.99, gae_lambda=0.97, cuda=False) + for i in range(len(output)): + for j in range(output[i]['adv'].shape[1]): + assert output[i]['adv'][0][j].eq(output2[i]['adv'][0][j]) + def test_get_nstep_return_data(self): nstep = 3 data = deque([self.get_transition() for _ in range(10)]) @@ -96,3 +138,7 @@ class TestAdder: assert output[-1]['done'][-1] is True assert output[-1]['done'][0] is False assert id(output[-1]['obs'][-1]) != id(output[-1]['obs'][0]) + + +test = TestAdder() +test.test_get_gae_multi_agent() diff --git a/ding/rl_utils/tests/test_gae.py b/ding/rl_utils/tests/test_gae.py index 
0d1e852..d8bcbf7 100644 --- a/ding/rl_utils/tests/test_gae.py +++ b/ding/rl_utils/tests/test_gae.py @@ -13,3 +13,14 @@ def test_gae(): data = gae_data(value, next_value, reward, done) adv = gae(data) assert adv.shape == (T, B) + + +def test_gae_multi_agent(): + T, B, A = 32, 4, 8 + value = torch.randn(T, B, A) + next_value = torch.randn(T, B, A) + reward = torch.randn(T, B) + done = torch.zeros(T, B) + data = gae_data(value, next_value, reward, done) + adv = gae(data) + assert adv.shape == (T, B, A) diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 4e85e55..702a24f 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -165,7 +165,10 @@ def diff_shape_collate(batch: Sequence) -> Union[torch.Tensor, Mapping, Sequence raise TypeError('not support element type: {}'.format(elem_type)) -def default_decollate(batch: Union[torch.Tensor, Sequence, Mapping], ignore: List[str] = ['prev_state']) -> List[Any]: +def default_decollate( + batch: Union[torch.Tensor, Sequence, Mapping], + ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state'] +) -> List[Any]: """ Overview: Drag out batch_size collated data's batch size to decollate it, diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py index 92b139a..cf97a2b 100644 --- a/ding/utils/default_helper.py +++ b/ding/utils/default_helper.py @@ -3,7 +3,6 @@ import logging import random from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any from functools import lru_cache # in python3.9, we can change to cache - import numpy as np import torch @@ -414,11 +413,15 @@ def one_time_warning(warning_msg: str) -> None: def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict: assert isinstance(data, dict), type(data) length = [] - for v in data.values(): + for k, v in data.items(): if v is None: continue + elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']: + length.append(len(v)) elif isinstance(v, list) or isinstance(v, tuple): length.append(len(v[0])) + elif isinstance(v, dict): + length.append(len(v[list(v.keys())[0]])) else: length.append(len(v)) assert len(length) > 0 @@ -436,8 +439,12 @@ def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> d for k in data.keys(): if data[k] is None: batch[k] = None + elif k.startswith('prev_state'): + batch[k] = [data[k][t] for t in indices[i:i + split_size]] elif isinstance(data[k], list) or isinstance(data[k], tuple): batch[k] = [t[indices[i:i + split_size]] for t in data[k]] + elif isinstance(data[k], dict): + batch[k] = {k1: v1[indices[i:i + split_size]] for k1, v1 in data[k].items()} else: batch[k] = data[k][indices[i:i + split_size]] yield batch @@ -453,7 +460,7 @@ class RunningMeanStd(object): - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count`` """ - def __init__(self, epsilon=1e-4, shape=()): + def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')): """ Overview: Initialize ``self.`` See ``help(type(self))`` for accurate \ @@ -466,6 +473,7 @@ class RunningMeanStd(object): """ self._epsilon = epsilon self._shape = shape + self._device = device self.reset() def update(self, x): @@ -496,8 +504,11 @@ class RunningMeanStd(object): Overview: Resets the state of the environment and reset properties: ``_mean``, ``_var``, ``_count`` """ - self._mean = np.zeros(self._shape, 'float32') - self._var = np.ones(self._shape, 'float32') + if len(self._shape) > 0: + self._mean = np.zeros(self._shape, 
'float32') + self._var = np.ones(self._shape, 'float32') + else: + self._mean, self._var = 0., 1. self._count = self._epsilon @property @@ -506,7 +517,10 @@ class RunningMeanStd(object): Overview: Property ``mean`` gotten from ``self._mean`` """ - return self._mean + if np.isscalar(self._mean): + return self._mean + else: + return torch.FloatTensor(self._mean).to(self._device) @property def std(self) -> np.ndarray: @@ -514,7 +528,11 @@ class RunningMeanStd(object): Overview: Property ``std`` calculated from ``self._var`` and the epsilon value of ``self._epsilon`` """ - return np.sqrt(self._var + 1e-8) + std = np.sqrt(self._var + 1e-8) + if np.isscalar(std): + return std + else: + return torch.FloatTensor(std).to(self._device) @staticmethod def new_shape(obs_shape, act_shape, rew_shape): diff --git a/dizoo/smac/config/smac_3s5z_mappo_config.py b/dizoo/smac/config/smac_3s5z_mappo_config.py new file mode 100644 index 0000000..23c8c39 --- /dev/null +++ b/dizoo/smac/config/smac_3s5z_mappo_config.py @@ -0,0 +1,91 @@ +import sys +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +agent_num = 8 +collector_env_num = 8 +evaluator_env_num = 8 +special_global_state = True + +main_config = dict( + exp_name='smac_3s5z_ppo', + env=dict( + map_name='3s5z', + difficulty=7, + reward_only_positive=True, + mirror_opponent=False, + agent_num=agent_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=16, + stop_value=0.99, + death_mask=False, + special_global_state=special_global_state, + # save_replay_episodes = 1, + manager=dict( + shared_memory=False, + reset_timeout=6000, + ), + ), + policy=dict( + cuda=True, + multi_agent=True, + continuous=False, + model=dict( + # (int) agent_num: The number of the agent. + # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2. + agent_num=agent_num, + # (int) obs_shape: The shapeension of observation of each agent. + # For 3s5z, obs_shape=150; for 2c_vs_64zg, agent_num=404. + # (int) global_obs_shape: The shapeension of global observation. + # For 3s5z, obs_shape=216; for 2c_vs_64zg, agent_num=342. + agent_obs_shape=150, + #global_obs_shape=216, + global_obs_shape=295, + # (int) action_shape: The number of action which each agent can take. + # action_shape= the number of common action (6) + the number of enemies. + # For 3s5z, obs_shape=14 (6+8); for 2c_vs_64zg, agent_num=70 (6+64). 
+ action_shape=14, + # (List[int]) The size of hidden layer + # hidden_size_list=[64], + ), + # used in state_num of hidden_state + learn=dict( + # (bool) Whether to use multi gpu + multi_gpu=False, + epoch_per_collect=5, + batch_size=3200, + learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=False, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=10, + ignore_done=False, + ), + on_policy=True, + collect=dict(env_num=collector_env_num, n_sample=3200), + eval=dict(env_num=evaluator_env_num), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='smac', + import_names=['dizoo.smac.envs.smac_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), +) +create_config = EasyDict(create_config) diff --git a/dizoo/smac/config/smac_5m6m_mappo_config.py b/dizoo/smac/config/smac_5m6m_mappo_config.py new file mode 100644 index 0000000..3ac1840 --- /dev/null +++ b/dizoo/smac/config/smac_5m6m_mappo_config.py @@ -0,0 +1,90 @@ +import sys +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +agent_num = 5 +collector_env_num = 8 +evaluator_env_num = 8 +special_global_state = True, + +main_config = dict( + exp_name='smac_5m6m_ppo', + env=dict( + map_name='5m_vs_6m', + difficulty=7, + reward_only_positive=True, + mirror_opponent=False, + agent_num=agent_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=16, + stop_value=0.99, + death_mask=True, + special_global_state=special_global_state, + manager=dict( + shared_memory=False, + reset_timeout=6000, + ), + ), + policy=dict( + cuda=True, + multi_agent=True, + continuous=False, + model=dict( + # (int) agent_num: The number of the agent. + # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2. + agent_num=agent_num, + # (int) obs_shape: The shapeension of observation of each agent. + # For 3s5z, obs_shape=150; for 2c_vs_64zg, agent_num=404. + # (int) global_obs_shape: The shapeension of global observation. + # For 3s5z, obs_shape=216; for 2c_vs_64zg, agent_num=342. + agent_obs_shape=72, + #global_obs_shape=216, + global_obs_shape=152, + # (int) action_shape: The number of action which each agent can take. + # action_shape= the number of common action (6) + the number of enemies. + # For 3s5z, obs_shape=14 (6+8); for 2c_vs_64zg, agent_num=70 (6+64). 
+ action_shape=12, + # (List[int]) The size of hidden layer + # hidden_size_list=[64], + ), + # used in state_num of hidden_state + learn=dict( + # (bool) Whether to use multi gpu + multi_gpu=False, + epoch_per_collect=10, + batch_size=3200, + learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.05, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=False, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=10, + ignore_done=False, + ), + on_policy=True, + collect=dict(env_num=collector_env_num, n_sample=3200), + eval=dict(env_num=evaluator_env_num), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='smac', + import_names=['dizoo.smac.envs.smac_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), +) +create_config = EasyDict(create_config) diff --git a/dizoo/smac/config/smac_MMM2_mappo_config.py b/dizoo/smac/config/smac_MMM2_mappo_config.py new file mode 100644 index 0000000..81da4f4 --- /dev/null +++ b/dizoo/smac/config/smac_MMM2_mappo_config.py @@ -0,0 +1,89 @@ +import sys +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +agent_num = 10 +collector_env_num = 8 +evaluator_env_num = 8 +special_global_state = True + +main_config = dict( + exp_name='smac_MMM2_ppo', + env=dict( + map_name='MMM2', + difficulty=7, + reward_only_positive=True, + mirror_opponent=False, + agent_num=agent_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=16, + stop_value=0.99, + death_mask=True, + special_global_state=special_global_state, + manager=dict( + shared_memory=False, + reset_timeout=6000, + ), + ), + policy=dict( + cuda=True, + multi_agent=True, + continuous=False, + model=dict( + # (int) agent_num: The number of the agent. + # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2. + agent_num=agent_num, + # (int) obs_shape: The shapeension of observation of each agent. + # For 3s5z, obs_shape=150; for 2c_vs_64zg, agent_num=404. + # (int) global_obs_shape: The shapeension of global observation. + # For 3s5z, obs_shape=216; for 2c_vs_64zg, agent_num=342. + agent_obs_shape=204, + global_obs_shape=431, + # (int) action_shape: The number of action which each agent can take. + # action_shape= the number of common action (6) + the number of enemies. + # For 3s5z, obs_shape=14 (6+8); for 2c_vs_64zg, agent_num=70 (6+64). 
+ action_shape=18, + # (List[int]) The size of hidden layer + # hidden_size_list=[64], + ), + # used in state_num of hidden_state + learn=dict( + # (bool) Whether to use multi gpu + multi_gpu=False, + epoch_per_collect=5, + batch_size=1600, + learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=False, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=10, + ignore_done=False, + ), + on_policy=True, + collect=dict(env_num=collector_env_num, n_sample=3200), + eval=dict(env_num=evaluator_env_num), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='smac', + import_names=['dizoo.smac.envs.smac_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), +) +create_config = EasyDict(create_config) diff --git a/dizoo/smac/config/smac_MMM_mappo_config.py b/dizoo/smac/config/smac_MMM_mappo_config.py new file mode 100644 index 0000000..2de1cda --- /dev/null +++ b/dizoo/smac/config/smac_MMM_mappo_config.py @@ -0,0 +1,90 @@ +import sys +from copy import deepcopy +from ding.entry import serial_pipeline +from easydict import EasyDict + +agent_num = 10 +collector_env_num = 8 +evaluator_env_num = 8 +special_global_state = True, + +main_config = dict( + exp_name='smac_MMM_ppo', + env=dict( + map_name='MMM', + difficulty=7, + reward_only_positive=True, + mirror_opponent=False, + agent_num=agent_num, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=16, + stop_value=0.99, + death_mask=False, + special_global_state=special_global_state, + manager=dict( + shared_memory=False, + reset_timeout=6000, + ), + ), + policy=dict( + cuda=True, + multi_agent=True, + continuous=False, + model=dict( + # (int) agent_num: The number of the agent. + # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2. + agent_num=agent_num, + # (int) obs_shape: The shapeension of observation of each agent. + # For 3s5z, obs_shape=150; for 2c_vs_64zg, agent_num=404. + # (int) global_obs_shape: The shapeension of global observation. + # For 3s5z, obs_shape=216; for 2c_vs_64zg, agent_num=342. + agent_obs_shape=186, + #global_obs_shape=216, + global_obs_shape=389, + # (int) action_shape: The number of action which each agent can take. + # action_shape= the number of common action (6) + the number of enemies. + # For 3s5z, obs_shape=14 (6+8); for 2c_vs_64zg, agent_num=70 (6+64). 
+ action_shape=16, + # (List[int]) The size of hidden layer + # hidden_size_list=[64], + ), + # used in state_num of hidden_state + learn=dict( + # (bool) Whether to use multi gpu + multi_gpu=False, + epoch_per_collect=5, + batch_size=320, + learning_rate=5e-4, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) The loss weight of value network, policy network weight is set to 1 + value_weight=0.5, + # (float) The loss weight of entropy regularization, policy network weight is set to 1 + entropy_weight=0.01, + # (float) PPO clip ratio, defaults to 0.2 + clip_ratio=0.2, + # (bool) Whether to use advantage norm in a whole training batch + adv_norm=False, + value_norm=True, + ppo_param_init=True, + grad_clip_type='clip_norm', + grad_clip_value=10, + ignore_done=False, + ), + on_policy=True, + collect=dict(env_num=collector_env_num, n_sample=3200), + eval=dict(env_num=evaluator_env_num), + ), +) +main_config = EasyDict(main_config) +create_config = dict( + env=dict( + type='smac', + import_names=['dizoo.smac.envs.smac_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), +) +create_config = EasyDict(create_config) diff --git a/dizoo/smac/envs/smac_env.py b/dizoo/smac/envs/smac_env.py index c299b46..d521e4b 100644 --- a/dizoo/smac/envs/smac_env.py +++ b/dizoo/smac/envs/smac_env.py @@ -4,6 +4,7 @@ from collections import namedtuple from operator import attrgetter import numpy as np +import math from easydict import EasyDict import pysc2.env.sc2_env as sc2_env from pysc2.env.sc2_env import SC2Env @@ -65,6 +66,12 @@ class SMACEnv(SC2Env, BaseEnv): obs_alone=False, game_steps_per_episode=None, reward_only_positive=True, + death_mask=False, + special_global_state=False, + # add map's center location ponit or not + add_center_xy=True, + # add agent's id information or not in special global state + state_agent_id=True, ) def __init__( @@ -126,6 +133,11 @@ class SMACEnv(SC2Env, BaseEnv): self.hydralisk_id = self.zergling_id = self.baneling_id = 0 self.stalker_id = self.colossus_id = self.zealot_id = 0 + self.add_center_xy = cfg.add_center_xy + self.state_agent_id = cfg.state_agent_id + self.death_mask = cfg.death_mask + self.special_global_state = cfg.special_global_state + # reward self.reward_death_value = cfg.reward_death_value self.reward_win = cfg.reward_win @@ -351,11 +363,18 @@ class SMACEnv(SC2Env, BaseEnv): 'action_mask': self.get_avail_actions() } else: - return { - 'agent_state': self.get_obs(), - 'global_state': self.get_state(), - 'action_mask': self.get_avail_actions() - } + if self.special_global_state: + return { + 'agent_state': self.get_obs(), + 'global_state': self.get_global_special_state(), + 'action_mask': self.get_avail_actions(), + } + else: + return { + 'agent_state': self.get_obs(), + 'global_state': self.get_state(), + 'action_mask': self.get_avail_actions(), + } return { 'agent_state': { @@ -451,11 +470,18 @@ class SMACEnv(SC2Env, BaseEnv): 'action_mask': self.get_avail_actions() } else: - obs = { - 'agent_state': self.get_obs(), - 'global_state': self.get_state(), - 'action_mask': self.get_avail_actions() - } + if self.special_global_state: + obs = { + 'agent_state': self.get_obs(), + 'global_state': self.get_global_special_state(), + 'action_mask': self.get_avail_actions(), + } + else: + obs = { + 'agent_state': self.get_obs(), + 'global_state': self.get_state(), + 'action_mask': self.get_avail_actions(), + } else: 
raise NotImplementedError @@ -1174,6 +1200,246 @@ class SMACEnv(SC2Env, BaseEnv): state = self._flatten_state(state) return np.array(state).astype(np.float32) + def get_global_special_state(self, is_opponent=False): + """Returns all agent observations in a list. + NOTE: Agents should have access only to their local observations + during decentralised execution. + """ + agents_obs_list = [self.get_state_agent(i, is_opponent) for i in range(self.n_agents)] + + return np.array(agents_obs_list).astype(np.float32) + + def get_global_special_state_size(self, is_opponent=False): + enemy_feats_dim = self.get_state_enemy_feats_size() + ally_feats_dim = self.get_state_ally_feats_size() + own_feats_dim = self.get_state_own_feats_size() + size = enemy_feats_dim + ally_feats_dim + own_feats_dim + self.n_agents + if self.state_timestep_number: + size += 1 + return size + + def get_state_agent(self, agent_id, is_opponent=False): + """Returns observation for agent_id. The observation is composed of: + + - agent movement features (where it can move to, height information and pathing grid) + - enemy features (available_to_attack, health, relative_x, relative_y, shield, unit_type) + - ally features (visible, distance, relative_x, relative_y, shield, unit_type) + - agent unit features (health, shield, unit_type) + + All of this information is flattened and concatenated into a list, + in the aforementioned order. To know the sizes of each of the + features inside the final list of features, take a look at the + functions ``get_obs_move_feats_size()``, + ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and + ``get_obs_own_feats_size()``. + + The size of the observation vector may vary, depending on the + environment configuration and type of units present in the map. + For instance, non-Protoss units will not have shields, movement + features may or may not include terrain height and pathing grid, + unit_type is not included if there is only one type of unit in the + map etc.). + + NOTE: Agents should have access only to their local observations + during decentralised execution. 
+ """ + if self.obs_instead_of_state: + obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32) + return obs_concat + + unit = self.get_unit_by_id(agent_id) + + enemy_feats_dim = self.get_state_enemy_feats_size() + ally_feats_dim = self.get_state_ally_feats_size() + own_feats_dim = self.get_state_own_feats_size() + + enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32) + ally_feats = np.zeros(ally_feats_dim, dtype=np.float32) + own_feats = np.zeros(own_feats_dim, dtype=np.float32) + agent_id_feats = np.zeros(self.n_agents, dtype=np.float32) + + center_x = self.map_x / 2 + center_y = self.map_y / 2 + + if (self.death_mask and unit.health > 0) or (not self.death_mask): # otherwise dead, return all zeros + x = unit.pos.x + y = unit.pos.y + sight_range = self.unit_sight_range(agent_id) + last_action = self.action_helper.get_last_action(is_opponent) + + # Movement features + avail_actions = self.get_avail_agent_actions(agent_id) + + # Enemy features + for e_id, e_unit in self.enemies.items(): + e_x = e_unit.pos.x + e_y = e_unit.pos.y + dist = self.distance(x, y, e_x, e_y) + + if e_unit.health > 0: # visible and alive + # Sight range > shoot range + if unit.health > 0: + enemy_feats[e_id, 0] = avail_actions[self.action_helper.n_actions_no_attack + e_id] # available + enemy_feats[e_id, 1] = dist / sight_range # distance + enemy_feats[e_id, 2] = (e_x - x) / sight_range # relative X + enemy_feats[e_id, 3] = (e_y - y) / sight_range # relative Y + if dist < sight_range: + enemy_feats[e_id, 4] = 1 # visible + + ind = 5 + if self.obs_all_health: + enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max) # health + ind += 1 + if self.shield_bits_enemy > 0: + max_shield = self.unit_max_shield(e_unit) + enemy_feats[e_id, ind] = (e_unit.shield / max_shield) # shield + ind += 1 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(e_unit, False) + enemy_feats[e_id, ind + type_id] = 1 # unit type + ind += self.unit_type_bits + + if self.add_center_xy: + enemy_feats[e_id, ind] = (e_x - center_x) / self.max_distance_x # center X + enemy_feats[e_id, ind + 1] = (e_y - center_y) / self.max_distance_y # center Y + + # Ally features + al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id] + for i, al_id in enumerate(al_ids): + + al_unit = self.get_unit_by_id(al_id) + al_x = al_unit.pos.x + al_y = al_unit.pos.y + dist = self.distance(x, y, al_x, al_y) + max_cd = self.unit_max_cooldown(al_unit) + + if al_unit.health > 0: # visible and alive + if unit.health > 0: + if dist < sight_range: + ally_feats[i, 0] = 1 # visible + ally_feats[i, 1] = dist / sight_range # distance + ally_feats[i, 2] = (al_x - x) / sight_range # relative X + ally_feats[i, 3] = (al_y - y) / sight_range # relative Y + + if (self.map_type == "MMM" and al_unit.unit_type == self.medivac_id): + ally_feats[i, 4] = al_unit.energy / max_cd # energy + else: + ally_feats[i, 4] = (al_unit.weapon_cooldown / max_cd) # cooldown + + ind = 5 + if self.obs_all_health: + ally_feats[i, ind] = (al_unit.health / al_unit.health_max) # health + ind += 1 + if self.shield_bits_ally > 0: + max_shield = self.unit_max_shield(al_unit) + ally_feats[i, ind] = (al_unit.shield / max_shield) # shield + ind += 1 + + if self.add_center_xy: + ally_feats[i, ind] = (al_x - center_x) / self.max_distance_x # center X + ally_feats[i, ind + 1] = (al_y - center_y) / self.max_distance_y # center Y + ind += 2 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(al_unit, True) + ally_feats[i, ind + type_id] = 1 + ind += 
self.unit_type_bits + + if self.state_last_action: + ally_feats[i, ind:] = last_action[al_id] + + # Own features + ind = 0 + own_feats[0] = 1 # visible + own_feats[1] = 0 # distance + own_feats[2] = 0 # X + own_feats[3] = 0 # Y + ind = 4 + if self.obs_own_health: + own_feats[ind] = unit.health / unit.health_max + ind += 1 + if self.shield_bits_ally > 0: + max_shield = self.unit_max_shield(unit) + own_feats[ind] = unit.shield / max_shield + ind += 1 + + if self.add_center_xy: + own_feats[ind] = (x - center_x) / self.max_distance_x # center X + own_feats[ind + 1] = (y - center_y) / self.max_distance_y # center Y + ind += 2 + + if self.unit_type_bits > 0: + type_id = self.get_unit_type_id(unit, True) + own_feats[ind + type_id] = 1 + ind += self.unit_type_bits + + if self.state_last_action: + own_feats[ind:] = last_action[agent_id] + + state = np.concatenate((ally_feats.flatten(), enemy_feats.flatten(), own_feats.flatten())) + + # Agent id features + if self.state_agent_id: + agent_id_feats[agent_id] = 1. + state = np.append(state, agent_id_feats.flatten()) + + if self.state_timestep_number: + state = np.append(state, self._episode_steps / self.episode_limit) + + return state + + def get_state_enemy_feats_size(self): + """ Returns the dimensions of the matrix containing enemy features. + Size is n_enemies x n_features. + """ + nf_en = 5 + self.unit_type_bits + + if self.obs_all_health: + nf_en += 1 + self.shield_bits_enemy + + if self.add_center_xy: + nf_en += 2 + + return self.n_enemies, nf_en + + def get_state_ally_feats_size(self): + """Returns the dimensions of the matrix containing ally features. + Size is n_allies x n_features. + """ + nf_al = 5 + self.unit_type_bits + + if self.obs_all_health: + nf_al += 1 + self.shield_bits_ally + + if self.state_last_action: + nf_al += self.n_actions + + if self.add_center_xy: + nf_al += 2 + + return self.n_agents - 1, nf_al + + def get_state_own_feats_size(self): + """Returns the size of the vector containing the agents' own features. + """ + own_feats = 4 + self.unit_type_bits + if self.obs_own_health: + own_feats += 1 + self.shield_bits_ally + + if self.state_last_action: + own_feats += self.n_actions + + if self.add_center_xy: + own_feats += 2 + + return own_feats + + @staticmethod + def distance(x1, y1, x2, y2): + """Distance between two points.""" + return math.hypot(x2 - x1, y2 - y1) + def unit_max_cooldown(self, unit, is_opponent=False): """Returns the maximal cooldown for a unit.""" if is_opponent: @@ -1329,14 +1595,24 @@ class SMACEnv(SC2Env, BaseEnv): None, ) else: - obs_space = T( - { - 'agent_state': (agent_num, self.get_obs_size(is_opponent)), - 'global_state': (self.get_state_size(is_opponent), ), - 'action_mask': (agent_num, *self.action_helper.info().shape), - }, - None, - ) + if self.special_global_state: + obs_space = T( + { + 'agent_state': (agent_num, self.get_obs_size(is_opponent)), + 'global_state': (agent_num, self.get_global_special_state_size(is_opponent)), + 'action_mask': (agent_num, *self.action_helper.info().shape), + }, + None, + ) + else: + obs_space = T( + { + 'agent_state': (agent_num, self.get_obs_size(is_opponent)), + 'global_state': (self.get_state_size(is_opponent), ), + 'action_mask': (agent_num, *self.action_helper.info().shape), + }, + None, + ) return self.SMACEnvInfo( agent_num=agent_num, obs_space=obs_space, -- GitLab
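
Usage sketch (not part of the patch above): assuming the modules added by this commit are importable as shown in the diff, training MAPPO on the SMAC 3s5z map could be launched via the new on-policy serial entry roughly as follows. The seed value is illustrative, and the `-m serial_onpolicy` CLI mode added in cli.py would be an alternative entry point.

# Hedged usage sketch: runs the new serial_pipeline_onpolicy entry with the
# SMAC 3s5z MAPPO config introduced by this commit. Assumes DI-engine (ding),
# pysc2/SMAC, and a local StarCraft II installation are available.
from copy import deepcopy

from ding.entry import serial_pipeline_onpolicy
from dizoo.smac.config.smac_3s5z_mappo_config import main_config, create_config

if __name__ == "__main__":
    # serial_pipeline_onpolicy accepts a (user_config, create_cfg) pair;
    # deepcopy keeps the module-level configs untouched across repeated runs.
    serial_pipeline_onpolicy((deepcopy(main_config), deepcopy(create_config)), seed=0)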