Commit f70d3ddb authored by niuyazhe

fix(nyz): fix wqmix target_model state_dict bug and polish mujoco model env

Parent db642fd3
@@ -286,3 +286,31 @@ class WQMIXPolicy(QMIXPolicy):
            by import_names path. For WQMIX, ``ding.model.template.wqmix``
        """
        return 'wqmix', ['ding.model.template.wqmix']
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        r"""
+        Overview:
+            Return the state_dict of learn mode, usually including model and optimizer.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+        """
+        return {
+            'model': self._learn_model.state_dict(),
+            'optimizer': self._optimizer.state_dict(),
+            'optimizer_star': self._optimizer_star.state_dict(),
+        }
+
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        r"""
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+        .. tip::
+            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operation.
+        """
+        self._learn_model.load_state_dict(state_dict['model'])
+        self._optimizer.load_state_dict(state_dict['optimizer'])
+        self._optimizer_star.load_state_dict(state_dict['optimizer_star'])
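The two hooks added above make WQMIX checkpoints carry ``optimizer_star`` (the second optimizer the policy keeps alongside ``_optimizer``) instead of falling back to the base-class state dict. A minimal, self-contained sketch of the save/restore round trip they enable, using a stand-in ``nn.Module`` and plain Adam optimizers rather than the real ``WQMIXPolicy`` internals:

```python
# Sketch only: stand-in objects, not the real WQMIXPolicy. The point is that the
# checkpoint must carry both optimizers, or 'optimizer_star' state is lost on resume.
import torch
import torch.nn as nn


def state_dict_learn(model, optimizer, optimizer_star):
    return {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'optimizer_star': optimizer_star.state_dict(),
    }


def load_state_dict_learn(state_dict, model, optimizer, optimizer_star):
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optimizer'])
    optimizer_star.load_state_dict(state_dict['optimizer_star'])


if __name__ == '__main__':
    net = nn.Linear(4, 2)  # stand-in for the learn model
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    opt_star = torch.optim.Adam(net.parameters(), lr=1e-3)
    torch.save(state_dict_learn(net, opt, opt_star), 'ckpt.pth.tar')
    # ... later, e.g. in a fresh process ...
    restored = torch.load('ckpt.pth.tar', map_location='cpu')
    load_state_dict_learn(restored, net, opt, opt_star)
```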
@@ -48,9 +48,7 @@ hopper_ppo_create_default_config = dict(
        import_names=['dizoo.mujoco.envs.mujoco_env'],
    ),
    env_manager=dict(type='subprocess'),
-    policy=dict(
-        type='ppo',
-    ),
+    policy=dict(type='ppo', ),
)
hopper_ppo_create_default_config = EasyDict(hopper_ppo_create_default_config)
create_config = hopper_ppo_create_default_config
@@ -37,7 +37,7 @@ halfcheetah_td3_default_config = dict(
            min=-0.5,
            max=0.5,
        ),
-        learner = dict(
+        learner=dict(
            load_path='./td3/ckpt/ckpt_best.pth.tar',
            hook=dict(
                load_ckpt_before_run='./td3/ckpt/ckpt_best.pth.tar',
...
@@ -22,7 +22,7 @@ y0 = rollout_length_min
y1 = rollout_length_max
w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
b = rollout_length_min
-set_rollout_length = lambda x: int( min( max( w * (x - x0) + b, y0 ), y1 ) )
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain
main_config = dict(
...
@@ -22,7 +22,7 @@ y0 = rollout_length_min
y1 = rollout_length_max
w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
b = rollout_length_min
-set_rollout_length = lambda x: int( min( max( w * (x - x0) + b, y0 ), y1 ) )
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain
main_config = dict(
...
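Both hunks above only reformat the same linear rollout-length schedule: the model rollout length ramps from ``rollout_length_min`` at ``rollout_start_step`` up to ``rollout_length_max`` at ``rollout_end_step`` and is clamped outside that range, and the imagination buffer size scales with it. A self-contained sketch with purely illustrative numbers (the real values sit in the collapsed part of each config):

```python
# Illustrative values only; the actual config defines these elsewhere.
rollout_start_step = 20000
rollout_end_step = 150000
rollout_length_min = 1
rollout_length_max = 25
rollout_batch_size = 50000
rollout_retain = 4

x0 = rollout_start_step
y0 = rollout_length_min
y1 = rollout_length_max
w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
b = rollout_length_min

# linear ramp w * (x - x0) + b, clipped to [y0, y1]
set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain

print(set_rollout_length(0))        # 1  (before the ramp starts)
print(set_rollout_length(85000))    # 13 (mid-ramp)
print(set_rollout_length(10 ** 6))  # 25 (clamped at the max)
```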
@@ -23,6 +23,7 @@ def eval_ckpt(args):
    eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
    # eval(config, seed=args.seed, state_dict=state_dict)

+
def generate(args):
    main_config.exp_name = 'td3'
    main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
@@ -30,8 +31,14 @@ def generate(args):
    main_config.policy.collect.data_type = 'hdf5'
    config = deepcopy([main_config, create_config])
    state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-        seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )

+
def train_expert(args):
    from dizoo.mujoco.config.hopper_td3_default_config import main_config, create_config
...
@@ -313,8 +313,8 @@ class MujocoEnv(BaseEnv):
            info.rew_space.shape = rew_shape
            return info
        else:
-            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]'\
-                .format(self._cfg.env_id, MUJOCO_INFO_DICT.keys()))
+            keys = MUJOCO_INFO_DICT.keys()
+            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]'.format(self._cfg.env_id, keys))

    def _make_env(self, only_info=False):
        return wrap_mujoco(
...
from typing import Any, Union, List, Callable, Dict
import copy
import torch
+import torch.nn as nn
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
@@ -11,8 +12,10 @@ from .mujoco_wrappers import wrap_mujoco
from ding.utils import ENV_REGISTRY
from ding.worker.collector.base_serial_collector import to_tensor_transitions

+
@ENV_REGISTRY.register('mujoco_model')
class MujocoModelEnv(object):
+
    def __init__(self, env_id: str, set_rollout_length: Callable, rollout_batch_size: int = 100000):
        self.env_id = env_id
        self.rollout_batch_size = rollout_batch_size
@@ -41,9 +44,9 @@ class MujocoModelEnv(object):
            done = ~not_done
            return done
        elif 'walker_' in self.env_id:
            torso_height = next_obs[:, -2]
            torso_ang = next_obs[:, -1]
-            if 'walker_7' in env_id or 'walker_5' in env_id:
+            if 'walker_7' in self.env_id or 'walker_5' in self.env_id:
                offset = 0.
            else:
                offset = 0.26
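The one-token fix above (``env_id`` to ``self.env_id``) matters because ``env_id`` is not defined inside ``termination_fn``, so the old branch would raise a ``NameError`` as soon as a ``walker_`` model env was rolled out. The branch follows the usual batched termination pattern; a generic sketch is below, where the height/angle bounds are illustrative assumptions rather than the values hidden in the collapsed lines:

```python
# Generic sketch of a batched walker-style termination check; the 0.8/2.0/±1.0
# bounds are illustrative assumptions, not the exact values of this env.
import torch


def walker_termination_fn(next_obs: torch.Tensor, offset: float = 0.26) -> torch.Tensor:
    torso_height = next_obs[:, -2]
    torso_ang = next_obs[:, -1]
    # keep rolling out while the torso stays upright within the assumed bounds
    not_done = (torso_height > 0.8 - offset) & (torso_height < 2.0 - offset) \
        & (torso_ang > -1.0) & (torso_ang < 1.0)
    return ~not_done  # done mask, shape (batch,)


print(walker_termination_fn(torch.randn(4, 17)))  # e.g. tensor([ True, False, ...])
```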
@@ -57,14 +60,20 @@ class MujocoModelEnv(object):
            done = torch.zeros_like(next_obs.sum(-1)).bool()
            return done

-    def rollout(self,
-            env_model: nn.Module,
-            policy: 'Policy',
-            replay_buffer: 'IBuffer',
-            imagine_buffer: 'IBuffer',
-            envstep: int,
-            cur_learner_iter: int) -> None:
-        # This function samples from the replay_buffer, rollouts to generate new data, and push them into the imagine_buffer
+    def rollout(
+            self,
+            env_model: nn.Module,
+            policy: 'Policy',  # noqa
+            replay_buffer: 'IBuffer',  # noqa
+            imagine_buffer: 'IBuffer',  # noqa
+            envstep: int,
+            cur_learner_iter: int
+    ) -> None:
+        """
+        Overview:
+            This function samples from the replay_buffer, rollouts to generate new data,
+            and push them into the imagine_buffer
+        """
        # set rollout length
        rollout_length = self._set_rollout_length(envstep)
        # load data
@@ -82,9 +91,7 @@ class MujocoModelEnv(object):
            timesteps = self.step(obs, actions, env_model)
            obs_new = {}
            for id, timestep in timesteps.items():
-                transition = policy.process_transition(
-                    obs[id], policy_output[id], timestep
-                )
+                transition = policy.process_transition(obs[id], policy_output[id], timestep)
                transition['collect_iter'] = cur_learner_iter
                buffer[id].append(transition)
                if not timestep.done:
@@ -102,9 +109,12 @@ class MujocoModelEnv(object):
    def step(self, obs: Dict, act: Dict, env_model: nn.Module) -> Dict:
        # This function has the same input and output format as env manager's step
        data_id = list(obs.keys())
-        obs = torch.stack([obs[id] for id in data_id],dim=0)
-        act = torch.stack([act[id] for id in data_id],dim=0)
+        obs = torch.stack([obs[id] for id in data_id], dim=0)
+        act = torch.stack([act[id] for id in data_id], dim=0)
        rewards, next_obs = env_model.predict(obs, act)
        terminals = self.termination_fn(next_obs)
-        timesteps = {id:BaseEnvTimestep(n,r,d,{}) for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())}
+        timesteps = {
+            id: BaseEnvTimestep(n, r, d, {})
+            for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())
+        }
        return timesteps
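As the comment in ``step`` says, the method mirrors the env manager's interface: per-env-id dicts in, a per-env-id dict of ``BaseEnvTimestep`` out, with one batched model prediction in between. A self-contained sketch of that contract; ``DummyModel`` and the local ``BaseEnvTimestep`` namedtuple are stand-ins so the snippet runs without DI-engine:

```python
# Stand-alone sketch of the dict-in / dict-out contract of MujocoModelEnv.step;
# DummyModel and the local BaseEnvTimestep namedtuple are stand-ins, not DI-engine code.
from collections import namedtuple
import torch

BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])


class DummyModel:

    def predict(self, obs: torch.Tensor, act: torch.Tensor):
        # pretend dynamics model: reward is the action norm, next_obs is a small drift of obs
        return act.norm(dim=-1), obs + 0.01


def model_step(obs: dict, act: dict, env_model) -> dict:
    data_id = list(obs.keys())
    obs_batch = torch.stack([obs[i] for i in data_id], dim=0)
    act_batch = torch.stack([act[i] for i in data_id], dim=0)
    rewards, next_obs = env_model.predict(obs_batch, act_batch)
    terminals = torch.zeros(len(data_id), dtype=torch.bool)  # termination_fn stand-in
    return {
        i: BaseEnvTimestep(n, r, d, {})
        for i, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())
    }


obs = {0: torch.randn(11), 1: torch.randn(11)}
act = {0: torch.randn(3), 1: torch.randn(3)}
print(model_step(obs, act, DummyModel()))
```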