DI-engine (OpenDILab open-source decision intelligence platform)

Commit f70d3ddb
Authored Nov 03, 2021 by niuyazhe
fix(nyz): fix wqmix target_model state_dict bug and polish mujoco model env
Parent: db642fd3
Showing 8 changed files with 69 additions and 26 deletions (+69 −26)
ding/policy/wqmix.py                                         +28   −0
dizoo/mujoco/config/hopper_ppo_default_config.py              +1   −3
dizoo/mujoco/config/hopper_td3_data_generation_config.py      +1   −1
dizoo/mujoco/config/sac_halfcheetah_mbpo_default_config.py    +1   −1
dizoo/mujoco/config/sac_hopper_mbpo_default_config.py         +1   −1
dizoo/mujoco/entry/mujoco_td3_bc_main.py                      +9   −2
dizoo/mujoco/envs/mujoco_env.py                               +2   −2
dizoo/mujoco/envs/mujoco_model_env.py                        +26  −16
ding/policy/wqmix.py
@@ -286,3 +286,31 @@ class WQMIXPolicy(QMIXPolicy):
             by import_names path. For WQMIX, ``ding.model.template.wqmix``
         """
         return 'wqmix', ['ding.model.template.wqmix']
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        r"""
+        Overview:
+            Return the state_dict of learn mode, usually including model and optimizer.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+        """
+        return {
+            'model': self._learn_model.state_dict(),
+            'optimizer': self._optimizer.state_dict(),
+            'optimizer_star': self._optimizer_star.state_dict(),
+        }
+
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        r"""
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+        .. tip::
+            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operation.
+        """
+        self._learn_model.load_state_dict(state_dict['model'])
+        self._optimizer.load_state_dict(state_dict['optimizer'])
+        self._optimizer_star.load_state_dict(state_dict['optimizer_star'])
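WQMIX keeps a second optimizer (`_optimizer_star`) for its Q* branch; without these overrides, the checkpoint hooks inherited from QMIX would neither save nor restore it. A minimal round-trip sketch of the save/load pattern the new hooks implement, using stand-in objects and an illustrative path (not taken from the commit):

import torch
import torch.nn as nn

# Stand-ins for the policy's learn model and its two optimizers; in WQMIX both
# `_optimizer` and `_optimizer_star` must survive a checkpoint round trip.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_star = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save, mirroring the new _state_dict_learn:
torch.save(
    {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'optimizer_star': optimizer_star.state_dict(),
    }, 'ckpt_demo.pth.tar'  # illustrative path
)

# Restore, mirroring the new _load_state_dict_learn:
state_dict = torch.load('ckpt_demo.pth.tar', map_location='cpu')
model.load_state_dict(state_dict['model'])
optimizer.load_state_dict(state_dict['optimizer'])
optimizer_star.load_state_dict(state_dict['optimizer_star'])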
dizoo/mujoco/config/hopper_ppo_default_config.py
@@ -48,9 +48,7 @@ hopper_ppo_create_default_config = dict(
         import_names=['dizoo.mujoco.envs.mujoco_env'],
     ),
     env_manager=dict(type='subprocess'),
-    policy=dict(
-        type='ppo',
-    ),
+    policy=dict(type='ppo', ),
 )
 hopper_ppo_create_default_config = EasyDict(hopper_ppo_create_default_config)
 create_config = hopper_ppo_create_default_config
dizoo/mujoco/config/hopper_td3_data_generation_config.py
@@ -37,7 +37,7 @@ halfcheetah_td3_default_config = dict(
             min=-0.5,
             max=0.5,
         ),
-        learner=dict(
+        learner=dict(
             load_path='./td3/ckpt/ckpt_best.pth.tar',
             hook=dict(
                 load_ckpt_before_run='./td3/ckpt/ckpt_best.pth.tar',
dizoo/mujoco/config/sac_halfcheetah_mbpo_default_config.py
@@ -22,7 +22,7 @@ y0 = rollout_length_min
 y1 = rollout_length_max
 w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
 b = rollout_length_min
-set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
 set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain

 main_config = dict(
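The `set_rollout_length` lambda touched in both MBPO configs is the usual MBPO schedule: the model-rollout horizon grows linearly with collected env steps between a start and an end step, clamped to [rollout_length_min, rollout_length_max]. A self-contained sketch with illustrative constants (the real configs define their own values):

# Illustrative values only; the real config defines its own constants.
rollout_length_min, rollout_length_max = 1, 15
rollout_start_step, rollout_end_step = 20000, 150000
x0, y0, y1 = rollout_start_step, rollout_length_min, rollout_length_max
w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
b = rollout_length_min

def set_rollout_length(envstep: int) -> int:
    # Linear ramp from min to max between start and end steps, clipped at both ends.
    return int(min(max(w * (envstep - x0) + b, y0), y1))

assert set_rollout_length(0) == 1          # before the ramp: clipped to the minimum
assert set_rollout_length(150000) == 15    # after the ramp: clipped to the maximum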
dizoo/mujoco/config/sac_hopper_mbpo_default_config.py
@@ -22,7 +22,7 @@ y0 = rollout_length_min
 y1 = rollout_length_max
 w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
 b = rollout_length_min
-set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
 set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain

 main_config = dict(
dizoo/mujoco/entry/mujoco_td3_bc_main.py
@@ -23,6 +23,7 @@ def eval_ckpt(args):
+    eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
     # eval(config, seed=args.seed, state_dict=state_dict)


 def generate(args):
     main_config.exp_name = 'td3'
     main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
@@ -30,8 +31,14 @@ def generate(args):
     main_config.policy.collect.data_type = 'hdf5'
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size, seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )


 def train_expert(args):
     from dizoo.mujoco.config.hopper_td3_default_config import main_config, create_config
dizoo/mujoco/envs/mujoco_env.py
@@ -313,8 +313,8 @@ class MujocoEnv(BaseEnv):
             info.rew_space.shape = rew_shape
             return info
         else:
-            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]' \
-                .format(self._cfg.env_id, MUJOCO_INFO_DICT.keys()))
+            keys = MUJOCO_INFO_DICT.keys()
+            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]'.format(self._cfg.env_id, keys))

     def _make_env(self, only_info=False):
         return wrap_mujoco(
dizoo/mujoco/envs/mujoco_model_env.py
 from typing import Any, Union, List, Callable, Dict
 import copy
 import torch
+import torch.nn as nn
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
@@ -11,8 +12,10 @@ from .mujoco_wrappers import wrap_mujoco
 from ding.utils import ENV_REGISTRY
+from ding.worker.collector.base_serial_collector import to_tensor_transitions


+@ENV_REGISTRY.register('mujoco_model')
 class MujocoModelEnv(object):

     def __init__(self, env_id: str, set_rollout_length: Callable, rollout_batch_size: int = 100000):
         self.env_id = env_id
         self.rollout_batch_size = rollout_batch_size
@@ -41,9 +44,9 @@ class MujocoModelEnv(object):
             done = ~not_done
             return done
         elif 'walker_' in self.env_id:
-            torso_height = next_obs[:, -2]
+            torso_height = next_obs[:, -2]
             torso_ang = next_obs[:, -1]
-            if 'walker_7' in env_id or 'walker_5' in env_id:
+            if 'walker_7' in self.env_id or 'walker_5' in self.env_id:
                 offset = 0.
             else:
                 offset = 0.26
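The real fix in this hunk is `env_id` → `self.env_id`: the bare name was undefined inside the method and would have raised a NameError on any walker environment. For reference, the walker branch decides termination from the last two observation columns (torso height and torso angle) plus a per-variant height offset. A hedged sketch of that shape of termination function; the thresholds below are illustrative, not taken from the file:

import torch

def walker_termination(next_obs: torch.Tensor, offset: float = 0.26) -> torch.Tensor:
    # Layout matching the diff: height is the second-to-last column, angle the last.
    torso_height = next_obs[:, -2]
    torso_ang = next_obs[:, -1]
    # Illustrative healthy ranges; the real method encodes its own bounds.
    not_done = (torso_height > 0.8 + offset) & (torso_height < 2.0 + offset) \
        & (torso_ang > -1.0) & (torso_ang < 1.0)
    return ~not_done

done = walker_termination(torch.tensor([[1.2, 0.0], [0.1, 0.0]]))  # -> [False, True]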
@@ -57,14 +60,20 @@ class MujocoModelEnv(object):
             done = torch.zeros_like(next_obs.sum(-1)).bool()
             return done

-    def rollout(
-            self, env_model: nn.Module, policy: 'Policy', replay_buffer: 'IBuffer', imagine_buffer: 'IBuffer',
-            envstep: int, cur_learner_iter: int
-    ) -> None:
-        # This function samples from the replay_buffer, rollouts to generate new data, and push them into the imagine_buffer
+    def rollout(
+            self,
+            env_model: nn.Module,
+            policy: 'Policy',  # noqa
+            replay_buffer: 'IBuffer',  # noqa
+            imagine_buffer: 'IBuffer',  # noqa
+            envstep: int,
+            cur_learner_iter: int
+    ) -> None:
+        """
+        Overview:
+            This function samples from the replay_buffer, rollouts to generate new data,
+            and push them into the imagine_buffer
+        """
         # set rollout length
         rollout_length = self._set_rollout_length(envstep)
         # load data
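In outline, `rollout` performs MBPO-style branched rollouts: draw start states from the real replay_buffer, step the learned model for `rollout_length` steps under the current policy, and push the imagined transitions into imagine_buffer. A toy sketch of that loop; the stand-in policy_fn and model_fn are hypothetical, not DI-engine APIs:

import torch

def branched_rollout(initial_obs: torch.Tensor, horizon: int, policy_fn, model_fn):
    # Start from real states, step the learned model `horizon` times,
    # and collect the imagined transitions.
    transitions, obs = [], initial_obs
    for _ in range(horizon):
        act = policy_fn(obs)
        reward, next_obs = model_fn(obs, act)
        transitions.append((obs, act, reward, next_obs))
        obs = next_obs
    return transitions

policy_fn = lambda o: -0.5 * o                      # toy deterministic policy
model_fn = lambda o, a: (o.sum(-1), o + 0.1 * a)    # toy dynamics model
imagined = branched_rollout(torch.zeros(8, 3), horizon=3, policy_fn=policy_fn, model_fn=model_fn)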
@@ -82,9 +91,7 @@ class MujocoModelEnv(object):
             timesteps = self.step(obs, actions, env_model)
             obs_new = {}
             for id, timestep in timesteps.items():
-                transition = policy.process_transition(
-                    obs[id], policy_output[id], timestep
-                )
+                transition = policy.process_transition(obs[id], policy_output[id], timestep)
                 transition['collect_iter'] = cur_learner_iter
                 buffer[id].append(transition)
                 if not timestep.done:
@@ -102,9 +109,12 @@ class MujocoModelEnv(object):
     def step(self, obs: Dict, act: Dict, env_model: nn.Module) -> Dict:
         # This function has the same input and output format as env manager's step
         data_id = list(obs.keys())
-        obs = torch.stack([obs[id] for id in data_id], dim=0)
-        act = torch.stack([act[id] for id in data_id], dim=0)
+        obs = torch.stack([obs[id] for id in data_id], dim=0)
+        act = torch.stack([act[id] for id in data_id], dim=0)
         rewards, next_obs = env_model.predict(obs, act)
         terminals = self.termination_fn(next_obs)
-        timesteps = {id: BaseEnvTimestep(n, r, d, {}) for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())}
+        timesteps = {
+            id: BaseEnvTimestep(n, r, d, {})
+            for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())
+        }
         return timesteps
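`step` keeps the env manager's dict-in/dict-out contract while batching through the model once: stack per-env observations and actions, run one model prediction, then fan the results back out as per-env BaseEnvTimestep entries. A self-contained sketch; the `predict` stand-in is hypothetical, and BaseEnvTimestep is re-declared locally to keep the example runnable:

import torch
from collections import namedtuple

# Local stand-in; DI-engine provides this namedtuple in ding.envs.
BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])

def predict(obs: torch.Tensor, act: torch.Tensor):
    # Hypothetical env model: per-sample reward and next observation.
    return obs.sum(-1), obs + act

obs = {0: torch.zeros(3), 1: torch.ones(3)}
act = {0: torch.full((3, ), 0.1), 1: torch.full((3, ), 0.2)}
data_id = list(obs.keys())
obs_batch = torch.stack([obs[i] for i in data_id], dim=0)
act_batch = torch.stack([act[i] for i in data_id], dim=0)
rewards, next_obs = predict(obs_batch, act_batch)
terminals = torch.zeros(len(data_id)).bool()
timesteps = {
    i: BaseEnvTimestep(n, r, d, {})
    for i, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())
}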