OpenDILab / DI-engine
Commit e30a3d3c
Authored on Jul 14, 2021 by zhangyinmin
Parent: 8fffde51

modify the gae recomputation; add/update ppo config/entry.

Showing 9 changed files with 96 additions and 47 deletions (+96, -47)
Files changed:

  ding/model/common/head.py                                      +1   -1
  ding/policy/ppo.py                                             +12  -21
  ding/rl_utils/adder.py                                         +3   -2
  ding/rl_utils/gae.py                                           +9   -10
  dizoo/classic_control/pendulum/config/pendulum_ppo_config.py   +10  -10
  dizoo/mujoco/config/hopper_ppo_default_config.py               +3   -2
  dizoo/mujoco/entry/__init__.py                                 +0   -0
  dizoo/mujoco/entry/mujoco_ppo_main.py                          +58  -0
  dizoo/mujoco/envs/mujoco_wrappers.py                           +0   -1
ding/model/common/head.py

@@ -600,7 +600,7 @@ class RegressionHead(nn.Module):

 class ReparameterizationHead(nn.Module):
     default_sigma_type = ['fixed', 'independent', 'conditioned']
-    default_bound_type = ['tanh']
+    default_bound_type = ['tanh', None]

     def __init__(
             self,
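The only change here extends the allowed bound types of ReparameterizationHead from ['tanh'] to ['tanh', None], i.e. the predicted mean may now be left unbounded instead of being squashed. A minimal illustrative sketch of the two options (not the DI-engine implementation, just the idea behind the whitelist):

    import torch

    def bound_mu(mu: torch.Tensor, bound_type=None) -> torch.Tensor:
        # 'tanh' squashes the mean into (-1, 1); None leaves it unbounded,
        # mirroring the two entries now permitted in default_bound_type.
        if bound_type == 'tanh':
            return torch.tanh(mu)
        assert bound_type is None, "unsupported bound_type"
        return mu

    mu = torch.tensor([0.5, 3.0])
    print(bound_mu(mu, 'tanh'))  # values squashed into (-1, 1)
    print(bound_mu(mu, None))    # values passed through unchanged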
ding/policy/ppo.py

@@ -33,7 +33,7 @@ class PPOPolicy(Policy):
         priority=False,
         # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
         priority_IS_weight=False,
-        recompute_adv=False,
+        recompute_adv=True,
         continuous=True,
         learn=dict(
             # (bool) Whether to use multi gpu
@@ -90,10 +90,10 @@ class PPOPolicy(Policy):
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
                torch.nn.init.zeros_(m.bias)
-        # self._model._actor[-1].weight.data.mul_(0.1)
        if self._continuous:
            # init log sigma
-            # torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
+            if hasattr(self._model.actor_head, 'log_sigma_param'):
+                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
            for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
                if isinstance(m, torch.nn.Linear):
                    # orthogonal initialization
@@ -131,8 +131,6 @@ class PPOPolicy(Policy):
        # Main model
        self._learn_model.reset()

-        from torch.optim.lr_scheduler import LambdaLR
-        # self._lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lambda epoch: 1 - epoch / 1500.0)
-
    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        r"""
@@ -163,16 +161,17 @@ class PPOPolicy(Policy):
        for epoch in range(self._cfg.learn.epoch_per_collect):
            if self._recompute_adv:
                with torch.no_grad():
-                    obs = torch.cat([data['obs'], data['next_obs'][-1:]])
-                    value = self._learn_model.forward(obs, mode='compute_critic')['value']
+                    # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
+                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                    if self._value_norm:
                        value *= self._running_mean_std.std
+                        next_value *= self._running_mean_std.std

-                    gae_data_ = gae_data(value, data['reward'], data['done'])
+                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
                    # GAE need (T, B) shape input and return (T, B) output
                    data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
-                    value = value[:-1]
+                    # value = value[:-1]
                    unnormalized_returns = value + data['adv']

                    if self._value_norm:
@@ -186,10 +185,8 @@ class PPOPolicy(Policy):
            for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                adv = batch['adv']
-                # with torch.no_grad():
-                #     batch['return'] = batch['value'] + adv
                if self._adv_norm:
-                    # Normalize advantage in a total train_batch
+                    # Normalize advantage in a train_batch
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

                # Calculate ppo error
@@ -231,11 +228,9 @@ class PPOPolicy(Policy):
                    {
                        'mu_mean': output['logit'][0].mean().item(),
                        'sigma_mean': output['logit'][1].mean().item(),
-                        # 'sigma_grad': self._model.actor_head.log_sigma_param.grad.data.mean().item(),
                    }
                )
            return_infos.append(return_info)
-        # self._lr_scheduler.step()
        return return_infos

    def _state_dict_learn(self) -> Dict[str, Any]:
@@ -284,7 +279,6 @@ class PPOPolicy(Policy):
        if self._continuous:
            (mu, sigma), value = output['logit'], output['value']
            dist = Independent(Normal(mu, sigma), 1)
-            # action = torch.clamp(dist.sample(), min=-1, max=1)
            output['action'] = dist.sample()
        if self._cuda:
            output = to_device(output, 'cpu')
@@ -326,8 +320,7 @@ class PPOPolicy(Policy):
            data = to_device(data, self._device)
        # adder is defined in _init_collect
        if self._cfg.learn.ignore_done:
-            for i in range(len(data)):
-                data[i]['done'] = False
+            data[-1]['done'] = False

        if data[-1]['done']:
            last_value = torch.zeros(1)
@@ -346,7 +339,7 @@ class PPOPolicy(Policy):
            for i in range(len(data)):
                data[i]['value'] /= self._running_mean_std.std
        # remove next_obs for save memory when not recompute adv
        if not self._recompute_adv:
            for i in range(len(data)):
                data[i].pop('next_obs')
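The key change in _forward_learn: when recompute_adv is set, the critic re-evaluates value and next_value over the whole collected batch at the start of every optimisation epoch and GAE is recomputed from them, instead of concatenating obs with the last next_obs and slicing value[:-1]. A minimal sketch of that recomputation, assuming (T, B)-shaped tensors, float-valued done flags, and a plain callable critic rather than the DI-engine forward API:

    import torch

    def recompute_advantage(critic, data, gamma=0.99, lam=0.97):
        # data['obs'], data['next_obs']: (T, B, obs_dim); data['reward'], data['done']: (T, B)
        with torch.no_grad():
            value = critic(data['obs'])            # V(s_t), shape (T, B)
            next_value = critic(data['next_obs'])  # V(s_{t+1}), shape (T, B)
            delta = data['reward'] + gamma * (1 - data['done']) * next_value - value
            adv = torch.zeros_like(delta)
            gae_item = torch.zeros_like(delta[0])
            for t in reversed(range(delta.shape[0])):
                gae_item = delta[t] + gamma * lam * (1 - data['done'][t]) * gae_item
                adv[t] = gae_item
            data['adv'] = adv
            # regression target for the critic (called unnormalized_returns in the diff)
            data['return'] = value + adv
        return data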
ding/rl_utils/adder.py

@@ -32,12 +32,13 @@ class Adder(object):
         Returns:
             - data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
         """
-        value = torch.stack([d['value'] for d in data] + [last_value])
+        value = torch.stack([d['value'] for d in data])
+        next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
         reward = torch.stack([d['reward'] for d in data])
         if cuda:
             value = value.cuda()
             reward = reward.cuda()
-        adv = gae(gae_data(value, reward, None), gamma, gae_lambda)
+        adv = gae(gae_data(value, next_value, reward, None), gamma, gae_lambda)
         if cuda:
             adv = adv.cpu()
         for i in range(len(data)):
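The adder now builds separate value and next_value tensors instead of one stacked tensor of length T+1, with the caller-supplied bootstrap last_value appended only to next_value. A small illustrative sketch with a hypothetical three-step transition list (the dict layout is assumed for illustration, not the Adder API):

    import torch

    # hypothetical trajectory of 3 transitions, each carrying the critic value V(s_t)
    data = [{'value': torch.tensor(v), 'reward': torch.tensor(1.0)} for v in (0.5, 0.4, 0.3)]
    last_value = torch.tensor(0.2)  # bootstrap value V(s_{T+1}) supplied by the caller

    value = torch.stack([d['value'] for d in data])                           # V(s_0), V(s_1), V(s_2)
    next_value = torch.stack([d['value'] for d in data][1:] + [last_value])   # V(s_1), V(s_2), V(s_3)
    # value[t] and next_value[t] are now aligned per step, which is the layout
    # the new gae_data(value, next_value, reward, done) namedtuple expects.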
ding/rl_utils/gae.py

@@ -2,7 +2,7 @@ from collections import namedtuple
 import torch
 from ding.hpc_rl import hpc_wrapper

-gae_data = namedtuple('gae_data', ['value', 'reward', 'done'])
+gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done'])


 def shape_fn_gae(args, kwargs):
@@ -43,17 +43,16 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
         value_{T+1} should be 0 if this trajectory reached a terminal state(done=True), otherwise we use value
         function, this operation is implemented in collector for packing trajectory.
     """
-    value, reward, done = data
+    value, next_value, reward, done = data
     if done is None:
-        delta = reward + gamma * value[1:] - value[:-1]
-    else:
-        delta = reward + (1 - done) * gamma * value[1:] - value[:-1]
+        done = torch.zeros_like(reward, device=reward.device)
+    delta = reward + (1 - done) * gamma * next_value - value
     factor = gamma * lambda_
-    adv = torch.zeros_like(reward)
+    adv = torch.zeros_like(reward, device=reward.device)
     gae_item = 0.
-    denom = 0.
     for t in reversed(range(reward.shape[0])):
-        denom = 1 + lambda_ * denom
-        gae_item = denom * delta[t] + factor * gae_item
-        adv[t] += gae_item / denom
+        gae_item = delta[t] + factor * gae_item * (1 - done[t])
+        adv[t] += gae_item
     return adv
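With the extra next_value field the recursion becomes the textbook GAE, delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}, and the previous denom-based per-step averaging is dropped. A standalone numeric sketch of the new recursion (a simplified re-statement for illustration, not an import from ding.rl_utils):

    import torch

    def gae_sketch(value, next_value, reward, done, gamma=0.99, lambda_=0.97):
        delta = reward + (1 - done) * gamma * next_value - value
        adv = torch.zeros_like(reward)
        gae_item = torch.zeros_like(reward[0])
        for t in reversed(range(reward.shape[0])):
            # zero the carried advantage when step t ends the episode
            gae_item = delta[t] + gamma * lambda_ * gae_item * (1 - done[t])
            adv[t] = gae_item
        return adv

    value      = torch.tensor([0.5, 0.4, 0.3])
    next_value = torch.tensor([0.4, 0.3, 0.0])
    reward     = torch.tensor([1.0, 1.0, 1.0])
    done       = torch.tensor([0.0, 0.0, 1.0])  # episode terminates at the last step
    print(gae_sketch(value, next_value, reward, done))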
dizoo/classic_control/pendulum/config/pendulum_ppo_config.py

@@ -2,7 +2,7 @@ from easydict import EasyDict
 pendulum_ppo_config = dict(
     env=dict(
-        collector_env_num=16,
+        collector_env_num=1,
         evaluator_env_num=5,
         act_scale=True,
         n_evaluator_episode=5,
@@ -12,7 +12,7 @@ pendulum_ppo_config = dict(
         cuda=False,
         on_policy=True,
         continuous=True,
-        recompute_adv=True,
+        recompute_adv=False,
         model=dict(
             obs_shape=3,
             action_shape=1,
@@ -20,25 +20,25 @@ pendulum_ppo_config = dict(
             continuous=True,
             actor_head_layer_num=0,
             critic_head_layer_num=0,
-            sigma_type='fixed',
+            sigma_type='conditioned',
             bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
-            batch_size=128,
-            learning_rate=1e-3,
+            batch_size=32,
+            learning_rate=3e-5,
             value_weight=0.5,
             entropy_weight=0.0,
             clip_ratio=0.2,
-            adv_norm=True,
+            adv_norm=False,
             value_norm=True,
-            ignore_done=True,
+            ignore_done=False,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=200,
             unroll_len=1,
-            discount_factor=0.95,
-            gae_lambda=0.95,
+            discount_factor=0.9,
+            gae_lambda=1.,
         ),
         eval=dict(evaluator=dict(eval_freq=200, ))
     ),
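A note on the new collect settings: with gae_lambda=1. the GAE estimator telescopes to the full discounted return minus the value baseline, A_t = sum_{l>=0} gamma^l * delta_{t+l} = G_t - V(s_t) (assuming the trajectory is bootstrapped with last_value at its end), so this config effectively trains Pendulum with Monte-Carlo-style advantages at discount_factor=0.9.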
dizoo/mujoco/config/hopper_ppo_default_config.py

@@ -5,7 +5,7 @@ hopper_ppo_default_config = dict(
         env_id='Hopper-v3',
         norm_obs=dict(use_norm=False, ),
         norm_reward=dict(use_norm=False, ),
-        collector_env_num=64,
+        collector_env_num=8,
         evaluator_env_num=10,
         use_act_scale=True,
         n_evaluator_episode=10,
@@ -14,6 +14,7 @@ hopper_ppo_default_config = dict(
     policy=dict(
         cuda=True,
         on_policy=True,
+        recompute_adv=True,
         model=dict(
             obs_shape=11,
             action_shape=3,
@@ -28,7 +29,7 @@ hopper_ppo_default_config = dict(
             entropy_weight=0.0,
             clip_ratio=0.2,
             adv_norm=True,
-            recompute_adv=True,
+            value_norm=True,
         ),
         collect=dict(
             n_sample=2048,
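Here recompute_adv moves from policy.learn up to the top policy level, matching the new recompute_adv=True default in ding/policy/ppo.py, while value_norm=True is added under learn. A tiny sketch of the resulting access pattern (hypothetical config fragment, not the full hopper config):

    from easydict import EasyDict

    # hypothetical fragment of the compiled config
    cfg = EasyDict(dict(policy=dict(recompute_adv=True, learn=dict(value_norm=True))))
    # after this commit the flag lives at the policy level rather than under learn,
    # presumably read by the policy as cfg.policy.recompute_adv
    assert cfg.policy.recompute_adv and cfg.policy.learn.value_norm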
dizoo/mujoco/entry/__init__.py

new file mode 100644 (empty)
dizoo/mujoco/entry/mujoco_ppo_main.py

new file mode 100644

import os
import gym
from tensorboardX import SummaryWriter
from easydict import EasyDict

from ding.config import compile_config
from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.classic_control.pendulum.envs import PendulumEnv
from dizoo.mujoco.envs.mujoco_env import MujocoEnv
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
from dizoo.mujoco.config.hopper_ppo_default_config import hopper_ppo_default_config


def main(cfg, seed=0, max_iterations=int(1e10)):
    cfg = compile_config(
        cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleCollector, BaseSerialEvaluator, NaiveReplayBuffer, save_cfg=True
    )
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
    )
    evaluator_env = BaseEnvManager(
        env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )

    collector_env.seed(seed, dynamic_seed=True)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
    collector = SampleCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
    evaluator = BaseSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)

    for _ in range(max_iterations):
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)


if __name__ == "__main__":
    main(hopper_ppo_default_config)
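Assuming MuJoCo is installed and licensed, this new entry can presumably be run directly from the repository root, e.g. "python dizoo/mujoco/entry/mujoco_ppo_main.py", training PPO on Hopper-v3 with hopper_ppo_default_config; using the imported pendulum_ppo_config instead would also require swapping MujocoEnv for PendulumEnv inside main().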
dizoo/mujoco/envs/mujoco_wrappers.py

 import gym
 import numpy as np
-import pybulletgym
 from ding.envs import ObsNormEnv, RewardNormEnv