DI-engine (OpenDILab open-source decision intelligence platform)
Commit f1bf66d0
Authored Sep 07, 2021 by niuyazhe

feature(wyh): add mappo algorithm for SMAC

Parent: 69828ed5
Showing 18 changed files with 1112 additions and 37 deletions (+1112 -37)
ding/entry/__init__.py                       +2   -1
ding/entry/cli.py                            +7   -2
ding/entry/serial_entry_onpolicy.py          +91  -0
ding/model/template/__init__.py              +1   -0
ding/model/template/mappo.py                 +251 -0
ding/policy/ppo.py                           +7   -3
ding/rl_utils/adder.py                       +1   -1
ding/rl_utils/gae.py                         +5   -3
ding/rl_utils/ppo.py                         +7   -1
ding/rl_utils/tests/test_adder.py            +46  -0
ding/rl_utils/tests/test_gae.py              +11  -0
ding/utils/data/collate_fn.py                +4   -1
ding/utils/default_helper.py                 +25  -7
dizoo/smac/config/smac_3s5z_mappo_config.py  +91  -0
dizoo/smac/config/smac_5m6m_mappo_config.py  +90  -0
dizoo/smac/config/smac_MMM2_mappo_config.py  +89  -0
dizoo/smac/config/smac_MMM_mappo_config.py   +90  -0
dizoo/smac/envs/smac_env.py                  +294 -18
ding/entry/__init__.py

```diff
 from .cli import cli
 from .serial_entry import serial_pipeline
+from .serial_entry_onpolicy import serial_pipeline_onpolicy
+from .serial_entry_offline import serial_pipeline_offline
 from .serial_entry_il import serial_pipeline_il
 from .serial_entry_reward_model import serial_pipeline_reward_model
 from .parallel_entry import parallel_pipeline
 from .application_entry import eval, collect_demo_data
-from .serial_entry_offline import serial_pipeline_offline
```
ding/entry/cli.py

```diff
@@ -52,7 +52,7 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
 @click.option(
     '-m',
     '--mode',
-    type=click.Choice(['serial', 'serial_sqil', 'parallel', 'dist', 'eval']),
+    type=click.Choice(['serial', 'serial_onpolicy', 'serial_sqil', 'parallel', 'dist', 'eval']),
     help='serial-train or parallel-train or dist-train or eval'
 )
 @click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@@ -144,7 +144,12 @@ def cli(
         if config is None:
             config = get_predefined_config(env, policy)
         serial_pipeline(config, seed, max_iterations=train_iter)
-    if mode == 'serial_sqil':
+    elif mode == 'serial_onpolicy':
+        from .serial_entry_onpolicy import serial_pipeline_onpolicy
+        if config is None:
+            config = get_predefined_config(env, policy)
+        serial_pipeline_onpolicy(config, seed, max_iterations=train_iter)
+    elif mode == 'serial_sqil':
         if config == 'lunarlander_sqil_config.py' or 'cartpole_sqil_config.py' or 'pong_sqil_config.py' \
                 or 'spaceinvaders_sqil_config.py' or 'qbert_sqil_config.py':
             from .serial_entry_sqil import serial_pipeline_sqil
```
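With the two hunks above, the new mode is reachable from the command line. A hypothetical invocation (the `ding` console entry point and its seed flag pre-date this commit; the config path is one of the files added later in this diff):

    ding -m serial_onpolicy -c dizoo/smac/config/smac_3s5z_mappo_config.py -s 0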
ding/entry/serial_entry_onpolicy.py (new file, mode 100644)

```python
from typing import Union, Optional, List, Any, Tuple
import os
import torch
import logging
from functools import partial
from tensorboardX import SummaryWriter

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed


def serial_pipeline_onpolicy(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_iterations: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for on-policy algorithms (such as PPO).
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iterations (:obj:`Optional[int]`): Learner's max iteration. Pipeline will stop \
            when reaching this iteration.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    # Create worker components: learner, collector, evaluator (no replay buffer for on-policy).
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = BaseSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    for _ in range(max_iterations):
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
        new_data = collector.collect(train_iter=learner.train_iter)
        # Learn policy from collected data
        learner.train(new_data, collector.envstep)

    # Learner's after_run hook.
    learner.call_hook('after_run')
    return policy
```
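A minimal launch sketch for this entry, assuming the 3s5z config added later in this commit is on the import path (the `__main__` guard and `deepcopy` are illustrative):

```python
# Hypothetical usage of serial_pipeline_onpolicy with one of the MAPPO configs
# added in this commit; deepcopy keeps the module-level configs untouched.
from copy import deepcopy
from ding.entry import serial_pipeline_onpolicy
from dizoo.smac.config.smac_3s5z_mappo_config import main_config, create_config

if __name__ == "__main__":
    serial_pipeline_onpolicy((deepcopy(main_config), deepcopy(create_config)), seed=0)
```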
ding/model/template/__init__.py

```diff
@@ -10,3 +10,4 @@ from .atoc import ATOC
 from .sqn import SQN
 from .acer import ACER
 from .qtran import QTran
+from .mappo import MAPPO
```
ding/model/template/mappo.py (new file, mode 100644)

```python
from typing import Union, Dict, Optional
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('mappo')
class MAPPO(nn.Module):
    r"""
    Overview:
        The MAPPO model.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            agent_num: int,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        r"""
        Overview:
            Init the MAPPO model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Each agent's local observation space.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space.
            - agent_num (:obj:`int`): The number of agents.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
            - actor_head_layer_num (:obj:`int`): The number of layers in the actor's head network.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
            - critic_head_layer_num (:obj:`int`): The number of layers in the critic's head network.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in ``MLP`` after ``layer_fn``; \
                if ``None``, defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use; \
                see ``ding.torch_utils.fc_block`` for more details.
        """
        super(MAPPO, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = \
            global_obs_shape, agent_obs_shape, action_shape
        # Encoder Type
        if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(agent_obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                format(agent_obs_shape)
            )
        if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1:
            global_encoder_cls = FCEncoder
        elif len(global_obs_shape) == 3:
            global_encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                format(global_obs_shape)
            )
        self.actor_encoder = encoder_cls(
            agent_obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
        )
        self.critic_encoder = global_encoder_cls(
            global_obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
        )
        # Head Type
        self.critic_head = RegressionHead(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        actor_head_cls = DiscreteHead
        self.actor_head = actor_head_cls(
            actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
        )
        # must use list, not nn.ModuleList
        self.actor = [self.actor_encoder, self.actor_head]
        self.critic = [self.critic_encoder, self.critic_head]
        # for the convenience of calling some APIs (such as self.critic.parameters()),
        # but may cause misunderstanding when printing the model
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        r"""
        Overview:
            Dispatch the input to ``compute_actor``, ``compute_critic`` or ``compute_actor_critic``
            according to ``mode``.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`): Input data, a dict with keys such as \
                ``agent_state``, ``global_state`` and ``action_mask``.
            - mode (:obj:`str`): One of ``self.mode``.
        Returns:
            - outputs (:obj:`Dict`): Output of the corresponding computation graph.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: Dict) -> Dict:
        r"""
        Overview:
            Execute the ``'compute_actor'`` graph: encode each agent's local observation and
            output action logits with unavailable actions masked out.
        Arguments:
            - x (:obj:`Dict`): Input dict with keys ``agent_state`` and ``action_mask``.
        Returns:
            - outputs (:obj:`Dict`): Dict with keys ``logit`` and ``action_mask``.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``.
        """
        action_mask = x['action_mask']
        x = x['agent_state']
        x = self.actor_encoder(x)
        x = self.actor_head(x)
        logit = x['logit']
        logit[action_mask == 0.0] = -99999999
        return {'logit': logit, 'action_mask': action_mask}

    def compute_critic(self, x: Dict) -> Dict:
        r"""
        Overview:
            Execute the ``'compute_critic'`` graph: encode the global state and output the value.
        Arguments:
            - x (:obj:`Dict`): Input dict with key ``global_state``.
        Returns:
            - outputs (:obj:`Dict`): Dict with key ``value``.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
        """
        x = self.critic_encoder(x['global_state'])
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Dict) -> Dict:
        r"""
        Overview:
            Execute both graphs in one call and return the combined dict.
        Arguments:
            - x (:obj:`Dict`): Input dict with keys ``agent_state``, ``global_state`` and ``action_mask``.
        Returns:
            - outputs (:obj:`Dict`): Dict with keys ``logit``, ``value`` and ``action_mask``.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``.
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

        .. note::
            The ``compute_actor_critic`` interface aims to save computation when the encoder is shared,
            returning the combined dictionary.
        """
        logit = self.compute_actor(x)['logit']
        value = self.compute_critic(x)['value']
        return {'logit': logit, 'value': value, 'action_mask': x['action_mask']}
```
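A hedged smoke-test sketch of MAPPO's dict-based interface (batch/agent sizes are illustrative, observation sizes taken from the 3s5z config below; the value shape assumes ``RegressionHead`` squeezes its trailing singleton dimension):

```python
# Illustrative MAPPO forward pass; all shapes are assumptions, not asserts
# from the source. 3s5z sizes: obs 150, global state 295, 14 actions.
import torch
from ding.model.template.mappo import MAPPO

B, A = 4, 8  # batch size, agent number
model = MAPPO(agent_obs_shape=150, global_obs_shape=295, action_shape=14, agent_num=A)
data = {
    'agent_state': torch.randn(B, A, 150),
    'global_state': torch.randn(B, A, 295),
    'action_mask': torch.randint(0, 2, (B, A, 14)).float(),
}
out = model(data, mode='compute_actor_critic')
print(out['logit'].shape)  # expected torch.Size([B, A, 14]); masked entries are -99999999
print(out['value'].shape)  # expected torch.Size([B, A]) if RegressionHead squeezes its last dim
```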
ding/policy/ppo.py

```diff
@@ -35,6 +35,7 @@ class PPOPolicy(Policy):
         priority_IS_weight=False,
         recompute_adv=True,
         continuous=True,
+        multi_agent=False,
         learn=dict(
             # (bool) Whether to use multi gpu
             multi_gpu=False,
@@ -124,7 +125,7 @@ class PPOPolicy(Policy):
         self._adv_norm = self._cfg.learn.adv_norm
         self._value_norm = self._cfg.learn.value_norm
         if self._value_norm:
-            self._running_mean_std = RunningMeanStd(epsilon=1e-4)
+            self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
         self._gamma = self._cfg.collect.discount_factor
         self._gae_lambda = self._cfg.collect.gae_lambda
         self._recompute_adv = self._cfg.recompute_adv
@@ -321,7 +322,7 @@ class PPOPolicy(Policy):
             data[-1]['done'] = False
         if data[-1]['done']:
-            last_value = torch.zeros(1)
+            last_value = torch.zeros_like(data[-1]['value'])
         else:
             with torch.no_grad():
                 last_value = self._collect_model.forward(
@@ -382,7 +383,10 @@ class PPOPolicy(Policy):
         return {i: d for i, d in zip(data_id, output)}

     def default_model(self) -> Tuple[str, List[str]]:
-        return 'vac', ['ding.model.template.vac']
+        if self._cfg.multi_agent:
+            return 'mappo', ['ding.model.template.mappo']
+        else:
+            return 'vac', ['ding.model.template.vac']

     def _monitor_vars_learn(self) -> List[str]:
         variables = super()._monitor_vars_learn() + [
```
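A note on the `torch.zeros_like` change (mirrored in `ding/rl_utils/adder.py` below): once values carry an agent dimension, the terminal bootstrap value must match their shape, since the GAE helper stacks it together with the per-step values. A minimal illustration, with assumed shapes:

```python
# Why torch.zeros(1) no longer works once values are per-agent (illustrative):
import torch

values = [torch.randn(1, 8) for _ in range(3)]  # per-step values, 8 agents each
last_value = torch.zeros_like(values[-1])       # (1, 8): matches the element shape
torch.stack(values[1:] + [last_value])          # ok: torch.stack needs equal shapes
# torch.stack(values[1:] + [torch.zeros(1)])    # would raise a size-mismatch error
```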
ding/rl_utils/adder.py

```diff
@@ -65,7 +65,7 @@ class Adder(object):
             extra advantage key 'adv'
         """
         if done:
-            last_value = torch.zeros(1)
+            last_value = torch.zeros_like(data[-1]['value'])
         else:
             last_data = data.pop()
             last_value = last_data['value']
```
ding/rl_utils/gae.py

```diff
@@ -46,11 +46,13 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
     value, next_value, reward, done = data
     if done is None:
         done = torch.zeros_like(reward, device=reward.device)
+    if len(value.shape) == len(reward.shape) + 1:
+        # for some marl case: value(T, B, A), reward(T, B)
+        reward = reward.unsqueeze(-1)
+        done = done.unsqueeze(-1)
     delta = reward + (1 - done) * gamma * next_value - value
     factor = gamma * lambda_
-    adv = torch.zeros_like(reward, device=reward.device)
-    gae_item = 0.
+    adv = torch.zeros_like(value, device=value.device)
+    gae_item = torch.zeros_like(value[0])
     for t in reversed(range(reward.shape[0])):
         gae_item = delta[t] + factor * gae_item * (1 - done[t])
```
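The unsqueeze lines let a shared team reward broadcast across the new trailing agent dimension of the value tensors; the new `test_gae_multi_agent` below exercises exactly this. A compact sketch with the same shapes:

```python
# Broadcast behaviour enabled by this change (shapes from the new test below).
import torch
from ding.rl_utils.gae import gae, gae_data

T, B, A = 32, 4, 8
value, next_value = torch.randn(T, B, A), torch.randn(T, B, A)
reward, done = torch.randn(T, B), torch.zeros(T, B)  # shared per-team signals
adv = gae(gae_data(value, next_value, reward, done))  # reward/done -> (T, B, 1)
assert adv.shape == (T, B, A)  # one advantage per agent
```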
ding/rl_utils/ppo.py

```diff
@@ -92,9 +92,14 @@ def ppo_policy_error(data: namedtuple,
     dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
     logp_new = dist_new.log_prob(action)
     logp_old = dist_old.log_prob(action)
-    entropy_loss = (dist_new.entropy() * weight).mean()
+    dist_new_entropy = dist_new.entropy()
+    if dist_new_entropy.shape != weight.shape:
+        dist_new_entropy = dist_new.entropy().mean(dim=1)
+    entropy_loss = (dist_new_entropy * weight).mean()
     # policy_loss
     ratio = torch.exp(logp_new - logp_old)
+    if ratio.shape != adv.shape:
+        ratio = ratio.mean(dim=1)
     surr1 = ratio * adv
     surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
     if dual_clip is not None:
@@ -103,6 +108,7 @@ def ppo_policy_error(data: namedtuple,
         # only use dual_clip when adv < 0
         policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
     else:
+        #policy_loss = (-torch.min(surr1, surr2) * weight).mean()
         policy_loss = (-torch.min(surr1, surr2) * weight).mean()
     with torch.no_grad():
         approx_kl = (logp_old - logp_new).mean().item()
```
ding/rl_utils/tests/test_adder.py

```diff
@@ -18,6 +18,15 @@ class TestAdder:
             'done': False
         }

+    def get_transition_multi_agent(self):
+        return {
+            'value': torch.randn(1, 8),
+            'reward': torch.rand(1, 1),
+            'other': np.random.randint(0, 10, size=(4, )),
+            'obs': torch.randn(3),
+            'done': False
+        }
+
     def test_get_gae(self):
         transitions = deque([self.get_transition() for _ in range(10)])
         last_value = torch.randn(1)
@@ -46,6 +55,39 @@ class TestAdder:
         for i in range(len(output)):
             assert output[i]['adv'].eq(output2[i]['adv'])

+    def test_get_gae_multi_agent(self):
+        transitions = deque([self.get_transition_multi_agent() for _ in range(10)])
+        last_value = torch.randn(1, 8)
+        output = get_gae(transitions, last_value, gamma=0.99, gae_lambda=0.97, cuda=False)
+        for i in range(len(output)):
+            o = output[i]
+            assert 'adv' in o.keys()
+            for k, v in o.items():
+                if k == 'adv':
+                    assert isinstance(v, torch.Tensor)
+                    assert v.shape == (1, 8, )
+                else:
+                    if k == 'done':
+                        assert v == transitions[i][k]
+                    else:
+                        assert (v == transitions[i][k]).all()
+        output1 = get_gae_with_default_last_value(
+            copy.deepcopy(transitions), True, gamma=0.99, gae_lambda=0.97, cuda=False
+        )
+        for i in range(len(output)):
+            for j in range(output[i]['adv'].shape[1]):
+                assert output[i]['adv'][0][j].ne(output1[i]['adv'][0][j])
+        data = copy.deepcopy(transitions)
+        data.append({'value': last_value})
+        output2 = get_gae_with_default_last_value(data, False, gamma=0.99, gae_lambda=0.97, cuda=False)
+        for i in range(len(output)):
+            for j in range(output[i]['adv'].shape[1]):
+                assert output[i]['adv'][0][j].eq(output2[i]['adv'][0][j])
+
     def test_get_nstep_return_data(self):
         nstep = 3
         data = deque([self.get_transition() for _ in range(10)])
@@ -96,3 +138,7 @@ class TestAdder:
         assert output[-1]['done'][-1] is True
         assert output[-1]['done'][0] is False
         assert id(output[-1]['obs'][-1]) != id(output[-1]['obs'][0])
+
+
+test = TestAdder()
+test.test_get_gae_multi_agent()
```
浏览文件 @
f1bf66d0
...
@@ -13,3 +13,14 @@ def test_gae():
...
@@ -13,3 +13,14 @@ def test_gae():
data
=
gae_data
(
value
,
next_value
,
reward
,
done
)
data
=
gae_data
(
value
,
next_value
,
reward
,
done
)
adv
=
gae
(
data
)
adv
=
gae
(
data
)
assert
adv
.
shape
==
(
T
,
B
)
assert
adv
.
shape
==
(
T
,
B
)
def
test_gae_multi_agent
():
T
,
B
,
A
=
32
,
4
,
8
value
=
torch
.
randn
(
T
,
B
,
A
)
next_value
=
torch
.
randn
(
T
,
B
,
A
)
reward
=
torch
.
randn
(
T
,
B
)
done
=
torch
.
zeros
(
T
,
B
)
data
=
gae_data
(
value
,
next_value
,
reward
,
done
)
adv
=
gae
(
data
)
assert
adv
.
shape
==
(
T
,
B
,
A
)
ding/utils/data/collate_fn.py

```diff
@@ -165,7 +165,10 @@ def diff_shape_collate(batch: Sequence) -> Union[torch.Tensor, Mapping, Sequence
     raise TypeError('not support element type: {}'.format(elem_type))


-def default_decollate(batch: Union[torch.Tensor, Sequence, Mapping], ignore: List[str] = ['prev_state']) -> List[Any]:
+def default_decollate(
+        batch: Union[torch.Tensor, Sequence, Mapping],
+        ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']
+) -> List[Any]:
     """
     Overview:
         Drag out batch_size collated data's batch size to decollate it,
```
ding/utils/default_helper.py

```diff
@@ -3,7 +3,6 @@ import logging
 import random
 from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any
 from functools import lru_cache  # in python3.9, we can change to cache
 import numpy as np
 import torch
@@ -414,11 +413,15 @@ def one_time_warning(warning_msg: str) -> None:
 def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
     assert isinstance(data, dict), type(data)
     length = []
-    for v in data.values():
+    for k, v in data.items():
         if v is None:
             continue
+        elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
+            length.append(len(v))
         elif isinstance(v, list) or isinstance(v, tuple):
             length.append(len(v[0]))
+        elif isinstance(v, dict):
+            length.append(len(v[list(v.keys())[0]]))
         else:
             length.append(len(v))
     assert len(length) > 0
@@ -436,8 +439,12 @@ def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> d
         for k in data.keys():
             if data[k] is None:
                 batch[k] = None
+            elif k.startswith('prev_state'):
+                batch[k] = [data[k][t] for t in indices[i:i + split_size]]
             elif isinstance(data[k], list) or isinstance(data[k], tuple):
                 batch[k] = [t[indices[i:i + split_size]] for t in data[k]]
+            elif isinstance(data[k], dict):
+                batch[k] = {k1: v1[indices[i:i + split_size]] for k1, v1 in data[k].items()}
             else:
                 batch[k] = data[k][indices[i:i + split_size]]
         yield batch
@@ -453,7 +460,7 @@ class RunningMeanStd(object):
         - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
     """

-    def __init__(self, epsilon=1e-4, shape=()):
+    def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
         """
         Overview:
             Initialize ``self.`` See ``help(type(self))`` for accurate \
@@ -466,6 +473,7 @@ class RunningMeanStd(object):
         """
         self._epsilon = epsilon
         self._shape = shape
+        self._device = device
         self.reset()

     def update(self, x):
@@ -496,8 +504,11 @@ class RunningMeanStd(object):
         Overview:
             Resets the state of the environment and reset properties: ``_mean``, ``_var``, ``_count``
         """
-        self._mean = np.zeros(self._shape, 'float32')
-        self._var = np.ones(self._shape, 'float32')
+        if len(self._shape) > 0:
+            self._mean = np.zeros(self._shape, 'float32')
+            self._var = np.ones(self._shape, 'float32')
+        else:
+            self._mean, self._var = 0., 1.
         self._count = self._epsilon

     @property
@@ -506,7 +517,10 @@ class RunningMeanStd(object):
         Overview:
             Property ``mean`` gotten from ``self._mean``
         """
-        return self._mean
+        if np.isscalar(self._mean):
+            return self._mean
+        else:
+            return torch.FloatTensor(self._mean).to(self._device)

     @property
     def std(self) -> np.ndarray:
@@ -514,7 +528,11 @@ class RunningMeanStd(object):
         Overview:
             Property ``std`` calculated from ``self._var`` and the epsilon value of ``self._epsilon``
         """
-        return np.sqrt(self._var + 1e-8)
+        std = np.sqrt(self._var + 1e-8)
+        if np.isscalar(std):
+            return std
+        else:
+            return torch.FloatTensor(std).to(self._device)

     @staticmethod
     def new_shape(obs_shape, act_shape, rew_shape):
```
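A rough sketch of the device-aware behaviour added to ``RunningMeanStd`` (``update()`` pre-dates this commit and is assumed here to accept an array-like batch):

```python
# Illustrative: with shape=() the statistics stay scalars; with a non-empty
# shape, mean/std now come back as FloatTensors on the configured device.
import numpy as np
import torch
from ding.utils.default_helper import RunningMeanStd

rms = RunningMeanStd(epsilon=1e-4, shape=(), device=torch.device('cpu'))
rms.update(np.random.randn(64))
print(type(rms.std))   # scalar path: a plain scalar, not a tensor

rms_vec = RunningMeanStd(epsilon=1e-4, shape=(8, ), device=torch.device('cpu'))
rms_vec.update(np.random.randn(64, 8))
print(rms_vec.std)     # tensor path: a FloatTensor on rms_vec's device
```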
dizoo/smac/config/smac_3s5z_mappo_config.py (new file, mode 100644)

```python
import sys
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict

agent_num = 8
collector_env_num = 8
evaluator_env_num = 8
special_global_state = True

main_config = dict(
    exp_name='smac_3s5z_ppo',
    env=dict(
        map_name='3s5z',
        difficulty=7,
        reward_only_positive=True,
        mirror_opponent=False,
        agent_num=agent_num,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=16,
        stop_value=0.99,
        death_mask=False,
        special_global_state=special_global_state,
        # save_replay_episodes=1,
        manager=dict(
            shared_memory=False,
            reset_timeout=6000,
        ),
    ),
    policy=dict(
        cuda=True,
        multi_agent=True,
        continuous=False,
        model=dict(
            # (int) agent_num: The number of agents.
            # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
            agent_num=agent_num,
            # (int) agent_obs_shape: The dimension of each agent's observation.
            # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
            # (int) global_obs_shape: The dimension of the global observation.
            # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
            agent_obs_shape=150,
            # global_obs_shape=216,
            global_obs_shape=295,
            # (int) action_shape: The number of actions each agent can take.
            # action_shape = the number of common actions (6) + the number of enemies.
            # For 3s5z, action_shape=14 (6 + 8); for 2c_vs_64zg, action_shape=70 (6 + 64).
            action_shape=14,
            # (List[int]) The size of hidden layer
            # hidden_size_list=[64],
        ),
        # used in state_num of hidden_state
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            epoch_per_collect=5,
            batch_size=3200,
            learning_rate=5e-4,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) The loss weight of the value network; policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization; policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            value_norm=True,
            ppo_param_init=True,
            grad_clip_type='clip_norm',
            grad_clip_value=10,
            ignore_done=False,
        ),
        on_policy=True,
        collect=dict(env_num=collector_env_num, n_sample=3200),
        eval=dict(env_num=evaluator_env_num),
    ),
)
main_config = EasyDict(main_config)
create_config = dict(
    env=dict(
        type='smac',
        import_names=['dizoo.smac.envs.smac_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(create_config)
```
dizoo/smac/config/smac_5m6m_mappo_config.py (new file, mode 100644)

```python
import sys
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict

agent_num = 5
collector_env_num = 8
evaluator_env_num = 8
special_global_state = True

main_config = dict(
    exp_name='smac_5m6m_ppo',
    env=dict(
        map_name='5m_vs_6m',
        difficulty=7,
        reward_only_positive=True,
        mirror_opponent=False,
        agent_num=agent_num,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=16,
        stop_value=0.99,
        death_mask=True,
        special_global_state=special_global_state,
        manager=dict(
            shared_memory=False,
            reset_timeout=6000,
        ),
    ),
    policy=dict(
        cuda=True,
        multi_agent=True,
        continuous=False,
        model=dict(
            # (int) agent_num: The number of agents.
            # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
            agent_num=agent_num,
            # (int) agent_obs_shape: The dimension of each agent's observation.
            # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
            # (int) global_obs_shape: The dimension of the global observation.
            # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
            agent_obs_shape=72,
            # global_obs_shape=216,
            global_obs_shape=152,
            # (int) action_shape: The number of actions each agent can take.
            # action_shape = the number of common actions (6) + the number of enemies.
            # For 3s5z, action_shape=14 (6 + 8); for 2c_vs_64zg, action_shape=70 (6 + 64).
            action_shape=12,
            # (List[int]) The size of hidden layer
            # hidden_size_list=[64],
        ),
        # used in state_num of hidden_state
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            epoch_per_collect=10,
            batch_size=3200,
            learning_rate=5e-4,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) The loss weight of the value network; policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization; policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.05,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            value_norm=True,
            ppo_param_init=True,
            grad_clip_type='clip_norm',
            grad_clip_value=10,
            ignore_done=False,
        ),
        on_policy=True,
        collect=dict(env_num=collector_env_num, n_sample=3200),
        eval=dict(env_num=evaluator_env_num),
    ),
)
main_config = EasyDict(main_config)
create_config = dict(
    env=dict(
        type='smac',
        import_names=['dizoo.smac.envs.smac_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(create_config)
```
dizoo/smac/config/smac_MMM2_mappo_config.py (new file, mode 100644)

```python
import sys
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict

agent_num = 10
collector_env_num = 8
evaluator_env_num = 8
special_global_state = True

main_config = dict(
    exp_name='smac_MMM2_ppo',
    env=dict(
        map_name='MMM2',
        difficulty=7,
        reward_only_positive=True,
        mirror_opponent=False,
        agent_num=agent_num,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=16,
        stop_value=0.99,
        death_mask=True,
        special_global_state=special_global_state,
        manager=dict(
            shared_memory=False,
            reset_timeout=6000,
        ),
    ),
    policy=dict(
        cuda=True,
        multi_agent=True,
        continuous=False,
        model=dict(
            # (int) agent_num: The number of agents.
            # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
            agent_num=agent_num,
            # (int) agent_obs_shape: The dimension of each agent's observation.
            # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
            # (int) global_obs_shape: The dimension of the global observation.
            # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
            agent_obs_shape=204,
            global_obs_shape=431,
            # (int) action_shape: The number of actions each agent can take.
            # action_shape = the number of common actions (6) + the number of enemies.
            # For 3s5z, action_shape=14 (6 + 8); for 2c_vs_64zg, action_shape=70 (6 + 64).
            action_shape=18,
            # (List[int]) The size of hidden layer
            # hidden_size_list=[64],
        ),
        # used in state_num of hidden_state
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            epoch_per_collect=5,
            batch_size=1600,
            learning_rate=5e-4,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) The loss weight of the value network; policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization; policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            value_norm=True,
            ppo_param_init=True,
            grad_clip_type='clip_norm',
            grad_clip_value=10,
            ignore_done=False,
        ),
        on_policy=True,
        collect=dict(env_num=collector_env_num, n_sample=3200),
        eval=dict(env_num=evaluator_env_num),
    ),
)
main_config = EasyDict(main_config)
create_config = dict(
    env=dict(
        type='smac',
        import_names=['dizoo.smac.envs.smac_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(create_config)
```
dizoo/smac/config/smac_MMM_mappo_config.py (new file, mode 100644)

```python
import sys
from copy import deepcopy
from ding.entry import serial_pipeline
from easydict import EasyDict

agent_num = 10
collector_env_num = 8
evaluator_env_num = 8
special_global_state = True

main_config = dict(
    exp_name='smac_MMM_ppo',
    env=dict(
        map_name='MMM',
        difficulty=7,
        reward_only_positive=True,
        mirror_opponent=False,
        agent_num=agent_num,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=16,
        stop_value=0.99,
        death_mask=False,
        special_global_state=special_global_state,
        manager=dict(
            shared_memory=False,
            reset_timeout=6000,
        ),
    ),
    policy=dict(
        cuda=True,
        multi_agent=True,
        continuous=False,
        model=dict(
            # (int) agent_num: The number of agents.
            # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
            agent_num=agent_num,
            # (int) agent_obs_shape: The dimension of each agent's observation.
            # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
            # (int) global_obs_shape: The dimension of the global observation.
            # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
            agent_obs_shape=186,
            # global_obs_shape=216,
            global_obs_shape=389,
            # (int) action_shape: The number of actions each agent can take.
            # action_shape = the number of common actions (6) + the number of enemies.
            # For 3s5z, action_shape=14 (6 + 8); for 2c_vs_64zg, action_shape=70 (6 + 64).
            action_shape=16,
            # (List[int]) The size of hidden layer
            # hidden_size_list=[64],
        ),
        # used in state_num of hidden_state
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            epoch_per_collect=5,
            batch_size=320,
            learning_rate=5e-4,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) The loss weight of the value network; policy network weight is set to 1
            value_weight=0.5,
            # (float) The loss weight of entropy regularization; policy network weight is set to 1
            entropy_weight=0.01,
            # (float) PPO clip ratio, defaults to 0.2
            clip_ratio=0.2,
            # (bool) Whether to use advantage norm in a whole training batch
            adv_norm=False,
            value_norm=True,
            ppo_param_init=True,
            grad_clip_type='clip_norm',
            grad_clip_value=10,
            ignore_done=False,
        ),
        on_policy=True,
        collect=dict(env_num=collector_env_num, n_sample=3200),
        eval=dict(env_num=evaluator_env_num),
    ),
)
main_config = EasyDict(main_config)
create_config = dict(
    env=dict(
        type='smac',
        import_names=['dizoo.smac.envs.smac_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
create_config = EasyDict(create_config)
```
dizoo/smac/envs/smac_env.py

```diff
@@ -4,6 +4,7 @@ from collections import namedtuple
 from operator import attrgetter
 import numpy as np
+import math
 from easydict import EasyDict
 import pysc2.env.sc2_env as sc2_env
 from pysc2.env.sc2_env import SC2Env
@@ -65,6 +66,12 @@ class SMACEnv(SC2Env, BaseEnv):
         obs_alone=False,
         game_steps_per_episode=None,
         reward_only_positive=True,
+        death_mask=False,
+        special_global_state=False,
+        # add map's center location point or not
+        add_center_xy=True,
+        # add agent's id information or not in special global state
+        state_agent_id=True,
     )

     def __init__(
@@ -126,6 +133,11 @@ class SMACEnv(SC2Env, BaseEnv):
         self.hydralisk_id = self.zergling_id = self.baneling_id = 0
         self.stalker_id = self.colossus_id = self.zealot_id = 0

+        self.add_center_xy = cfg.add_center_xy
+        self.state_agent_id = cfg.state_agent_id
+        self.death_mask = cfg.death_mask
+        self.special_global_state = cfg.special_global_state
+
         # reward
         self.reward_death_value = cfg.reward_death_value
         self.reward_win = cfg.reward_win
@@ -351,11 +363,18 @@ class SMACEnv(SC2Env, BaseEnv):
                 'action_mask': self.get_avail_actions()
             }
         else:
-            return {
-                'agent_state': self.get_obs(),
-                'global_state': self.get_state(),
-                'action_mask': self.get_avail_actions()
-            }
+            if self.special_global_state:
+                return {
+                    'agent_state': self.get_obs(),
+                    'global_state': self.get_global_special_state(),
+                    'action_mask': self.get_avail_actions(),
+                }
+            else:
+                return {
+                    'agent_state': self.get_obs(),
+                    'global_state': self.get_state(),
+                    'action_mask': self.get_avail_actions(),
+                }
         return {
             'agent_state': {
@@ -451,11 +470,18 @@ class SMACEnv(SC2Env, BaseEnv):
                 'action_mask': self.get_avail_actions()
             }
         else:
-            obs = {
-                'agent_state': self.get_obs(),
-                'global_state': self.get_state(),
-                'action_mask': self.get_avail_actions()
-            }
+            if self.special_global_state:
+                obs = {
+                    'agent_state': self.get_obs(),
+                    'global_state': self.get_global_special_state(),
+                    'action_mask': self.get_avail_actions(),
+                }
+            else:
+                obs = {
+                    'agent_state': self.get_obs(),
+                    'global_state': self.get_state(),
+                    'action_mask': self.get_avail_actions(),
+                }
         else:
             raise NotImplementedError
```

The `@@ -1174,6 +1200,246 @@` hunk appends the agent-specific global-state machinery after `get_state` (inserted before the unchanged `unit_max_cooldown`):

```python
    def get_global_special_state(self, is_opponent=False):
        """Returns all agent observations in a list.
        NOTE: Agents should have access only to their local observations
        during decentralised execution.
        """
        agents_obs_list = [self.get_state_agent(i, is_opponent) for i in range(self.n_agents)]
        return np.array(agents_obs_list).astype(np.float32)

    def get_global_special_state_size(self, is_opponent=False):
        enemy_feats_dim = self.get_state_enemy_feats_size()
        ally_feats_dim = self.get_state_ally_feats_size()
        own_feats_dim = self.get_state_own_feats_size()
        # enemy/ally dims are (count, features) tuples; the state is flattened
        size = np.prod(enemy_feats_dim) + np.prod(ally_feats_dim) + own_feats_dim + self.n_agents
        if self.state_timestep_number:
            size += 1
        return size

    def get_state_agent(self, agent_id, is_opponent=False):
        """Returns observation for agent_id. The observation is composed of:
        - agent movement features (where it can move to, height information and pathing grid)
        - enemy features (available_to_attack, health, relative_x, relative_y, shield, unit_type)
        - ally features (visible, distance, relative_x, relative_y, shield, unit_type)
        - agent unit features (health, shield, unit_type)
        All of this information is flattened and concatenated into a list,
        in the aforementioned order. To know the sizes of each of the
        features inside the final list of features, take a look at the
        functions ``get_obs_move_feats_size()``,
        ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and
        ``get_obs_own_feats_size()``.
        The size of the observation vector may vary, depending on the
        environment configuration and type of units present in the map.
        For instance, non-Protoss units will not have shields, movement
        features may or may not include terrain height and pathing grid,
        unit_type is not included if there is only one type of unit in the
        map, etc.
        NOTE: Agents should have access only to their local observations
        during decentralised execution.
        """
        if self.obs_instead_of_state:
            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
            return obs_concat

        unit = self.get_unit_by_id(agent_id)

        enemy_feats_dim = self.get_state_enemy_feats_size()
        ally_feats_dim = self.get_state_ally_feats_size()
        own_feats_dim = self.get_state_own_feats_size()

        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
        own_feats = np.zeros(own_feats_dim, dtype=np.float32)
        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)

        center_x = self.map_x / 2
        center_y = self.map_y / 2

        if (self.death_mask and unit.health > 0) or (not self.death_mask):  # otherwise dead, return all zeros
            x = unit.pos.x
            y = unit.pos.y
            sight_range = self.unit_sight_range(agent_id)
            last_action = self.action_helper.get_last_action(is_opponent)

            # Movement features
            avail_actions = self.get_avail_agent_actions(agent_id)

            # Enemy features
            for e_id, e_unit in self.enemies.items():
                e_x = e_unit.pos.x
                e_y = e_unit.pos.y
                dist = self.distance(x, y, e_x, e_y)
                if e_unit.health > 0:  # visible and alive
                    # Sight range > shoot range
                    if unit.health > 0:
                        enemy_feats[e_id, 0] = avail_actions[self.action_helper.n_actions_no_attack + e_id]  # available
                        enemy_feats[e_id, 1] = dist / sight_range  # distance
                        enemy_feats[e_id, 2] = (e_x - x) / sight_range  # relative X
                        enemy_feats[e_id, 3] = (e_y - y) / sight_range  # relative Y
                        if dist < sight_range:
                            enemy_feats[e_id, 4] = 1  # visible
                    ind = 5
                    if self.obs_all_health:
                        enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max)  # health
                        ind += 1
                        if self.shield_bits_enemy > 0:
                            max_shield = self.unit_max_shield(e_unit)
                            enemy_feats[e_id, ind] = (e_unit.shield / max_shield)  # shield
                            ind += 1
                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(e_unit, False)
                        enemy_feats[e_id, ind + type_id] = 1  # unit type
                        ind += self.unit_type_bits
                    if self.add_center_xy:
                        enemy_feats[e_id, ind] = (e_x - center_x) / self.max_distance_x  # center X
                        enemy_feats[e_id, ind + 1] = (e_y - center_y) / self.max_distance_y  # center Y

            # Ally features
            al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]
            for i, al_id in enumerate(al_ids):
                al_unit = self.get_unit_by_id(al_id)
                al_x = al_unit.pos.x
                al_y = al_unit.pos.y
                dist = self.distance(x, y, al_x, al_y)
                max_cd = self.unit_max_cooldown(al_unit)

                if al_unit.health > 0:  # visible and alive
                    if unit.health > 0:
                        if dist < sight_range:
                            ally_feats[i, 0] = 1  # visible
                        ally_feats[i, 1] = dist / sight_range  # distance
                        ally_feats[i, 2] = (al_x - x) / sight_range  # relative X
                        ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y

                    if (self.map_type == "MMM" and al_unit.unit_type == self.medivac_id):
                        ally_feats[i, 4] = al_unit.energy / max_cd  # energy
                    else:
                        ally_feats[i, 4] = (al_unit.weapon_cooldown / max_cd)  # cooldown

                    ind = 5
                    if self.obs_all_health:
                        ally_feats[i, ind] = (al_unit.health / al_unit.health_max)  # health
                        ind += 1
                        if self.shield_bits_ally > 0:
                            max_shield = self.unit_max_shield(al_unit)
                            ally_feats[i, ind] = (al_unit.shield / max_shield)  # shield
                            ind += 1

                    if self.add_center_xy:
                        ally_feats[i, ind] = (al_x - center_x) / self.max_distance_x  # center X
                        ally_feats[i, ind + 1] = (al_y - center_y) / self.max_distance_y  # center Y
                        ind += 2

                    if self.unit_type_bits > 0:
                        type_id = self.get_unit_type_id(al_unit, True)
                        ally_feats[i, ind + type_id] = 1
                        ind += self.unit_type_bits

                    if self.state_last_action:
                        ally_feats[i, ind:] = last_action[al_id]

            # Own features
            ind = 0
            own_feats[0] = 1  # visible
            own_feats[1] = 0  # distance
            own_feats[2] = 0  # X
            own_feats[3] = 0  # Y
            ind = 4
            if self.obs_own_health:
                own_feats[ind] = unit.health / unit.health_max
                ind += 1
                if self.shield_bits_ally > 0:
                    max_shield = self.unit_max_shield(unit)
                    own_feats[ind] = unit.shield / max_shield
                    ind += 1

            if self.add_center_xy:
                own_feats[ind] = (x - center_x) / self.max_distance_x  # center X
                own_feats[ind + 1] = (y - center_y) / self.max_distance_y  # center Y
                ind += 2

            if self.unit_type_bits > 0:
                type_id = self.get_unit_type_id(unit, True)
                own_feats[ind + type_id] = 1
                ind += self.unit_type_bits

            if self.state_last_action:
                own_feats[ind:] = last_action[agent_id]

        state = np.concatenate((ally_feats.flatten(), enemy_feats.flatten(), own_feats.flatten()))

        # Agent id features
        if self.state_agent_id:
            agent_id_feats[agent_id] = 1.
            state = np.append(state, agent_id_feats.flatten())

        if self.state_timestep_number:
            state = np.append(state, self._episode_steps / self.episode_limit)

        return state

    def get_state_enemy_feats_size(self):
        """Returns the dimensions of the matrix containing enemy features.
        Size is n_enemies x n_features.
        """
        nf_en = 5 + self.unit_type_bits

        if self.obs_all_health:
            nf_en += 1 + self.shield_bits_enemy

        if self.add_center_xy:
            nf_en += 2

        return self.n_enemies, nf_en

    def get_state_ally_feats_size(self):
        """Returns the dimensions of the matrix containing ally features.
        Size is n_allies x n_features.
        """
        nf_al = 5 + self.unit_type_bits

        if self.obs_all_health:
            nf_al += 1 + self.shield_bits_ally

        if self.state_last_action:
            nf_al += self.n_actions

        if self.add_center_xy:
            nf_al += 2

        return self.n_agents - 1, nf_al

    def get_state_own_feats_size(self):
        """Returns the size of the vector containing the agents' own features.
        """
        own_feats = 4 + self.unit_type_bits
        if self.obs_own_health:
            own_feats += 1 + self.shield_bits_ally

        if self.state_last_action:
            own_feats += self.n_actions

        if self.add_center_xy:
            own_feats += 2

        return own_feats

    @staticmethod
    def distance(x1, y1, x2, y2):
        """Distance between two points."""
        return math.hypot(x2 - x1, y2 - y1)
```

Finally, `info()` reports the per-agent global-state shape when `special_global_state` is set:

```diff
@@ -1329,14 +1595,24 @@ class SMACEnv(SC2Env, BaseEnv):
                 None,
             )
         else:
-            obs_space = T(
-                {
-                    'agent_state': (agent_num, self.get_obs_size(is_opponent)),
-                    'global_state': (self.get_state_size(is_opponent), ),
-                    'action_mask': (agent_num, *self.action_helper.info().shape),
-                },
-                None,
-            )
+            if self.special_global_state:
+                obs_space = T(
+                    {
+                        'agent_state': (agent_num, self.get_obs_size(is_opponent)),
+                        'global_state': (agent_num, self.get_global_special_state_size(is_opponent)),
+                        'action_mask': (agent_num, *self.action_helper.info().shape),
+                    },
+                    None,
+                )
+            else:
+                obs_space = T(
+                    {
+                        'agent_state': (agent_num, self.get_obs_size(is_opponent)),
+                        'global_state': (self.get_state_size(is_opponent), ),
+                        'action_mask': (agent_num, *self.action_helper.info().shape),
+                    },
+                    None,
+                )
         return self.SMACEnvInfo(
             agent_num=agent_num,
             obs_space=obs_space,
```