Commit f70d3ddb authored by niuyazhe

fix(nyz): fix wqmix target_model state_dict bug and polish mujoco model env

Parent db642fd3
@@ -286,3 +286,31 @@ class WQMIXPolicy(QMIXPolicy):
                 by import_names path. For WQMIX, ``ding.model.template.wqmix``
         """
         return 'wqmix', ['ding.model.template.wqmix']
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        r"""
+        Overview:
+            Return the state_dict of learn mode, usually including model and optimizer.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+        """
+        return {
+            'model': self._learn_model.state_dict(),
+            'optimizer': self._optimizer.state_dict(),
+            'optimizer_star': self._optimizer_star.state_dict(),
+        }
+
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        r"""
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+        .. tip::
+            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operations.
+        """
+        self._learn_model.load_state_dict(state_dict['model'])
+        self._optimizer.load_state_dict(state_dict['optimizer'])
+        self._optimizer_star.load_state_dict(state_dict['optimizer_star'])
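The two overrides above make the learn-mode checkpoint self-contained: WQMIX maintains a second optimizer, ``_optimizer_star`` (by its name, the optimizer of Weighted QMIX's unrestricted Q* network), and it is now saved and restored together with the model and the main optimizer, so a resumed run no longer reinitializes it. A minimal round-trip sketch (assuming an already-constructed WQMIX policy object named ``policy``; the checkpoint path is illustrative):

    import torch

    # Save the complete learn-mode state: model weights plus both optimizers.
    ckpt = policy._state_dict_learn()
    torch.save(ckpt, './wqmix_ckpt.pth.tar')

    # Restore it later; 'optimizer_star' is now round-tripped as well.
    state_dict = torch.load('./wqmix_ckpt.pth.tar', map_location='cpu')
    policy._load_state_dict_learn(state_dict)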
@@ -48,9 +48,7 @@ hopper_ppo_create_default_config = dict(
         import_names=['dizoo.mujoco.envs.mujoco_env'],
     ),
     env_manager=dict(type='subprocess'),
-    policy=dict(
-        type='ppo',
-    ),
+    policy=dict(type='ppo', ),
 )
 hopper_ppo_create_default_config = EasyDict(hopper_ppo_create_default_config)
 create_config = hopper_ppo_create_default_config
@@ -37,7 +37,7 @@ halfcheetah_td3_default_config = dict(
             min=-0.5,
             max=0.5,
         ),
-        learner = dict(
+        learner=dict(
            load_path='./td3/ckpt/ckpt_best.pth.tar',
            hook=dict(
                load_ckpt_before_run='./td3/ckpt/ckpt_best.pth.tar',
......
@@ -22,7 +22,7 @@ y0 = rollout_length_min
 y1 = rollout_length_max
 w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
 b = rollout_length_min
-set_rollout_length = lambda x: int( min( max( w * (x - x0) + b, y0 ), y1 ) )
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
 set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain
 main_config = dict(
......
@@ -22,7 +22,7 @@ y0 = rollout_length_min
 y1 = rollout_length_max
 w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
 b = rollout_length_min
-set_rollout_length = lambda x: int( min( max( w * (x - x0) + b, y0 ), y1 ) )
+set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
 set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain
 main_config = dict(
......
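The two identical hunks above (the same schedule appears in two model-based configs) only reformat the linear rollout-length schedule: the rollout length grows linearly with the collector's envstep and is clamped to [rollout_length_min, rollout_length_max], in the style of MBPO's linearly increasing model rollouts, and the imagination buffer is sized to match. A self-contained numeric sketch of the same formula (the concrete numbers are illustrative, not taken from these configs):

    rollout_length_min, rollout_length_max = 1, 25
    rollout_start_step, rollout_end_step = 20000, 150000
    rollout_batch_size, rollout_retain = 50000, 4

    x0, y0, y1 = rollout_start_step, rollout_length_min, rollout_length_max
    w = (rollout_length_max - rollout_length_min) / (rollout_end_step - rollout_start_step)
    b = rollout_length_min
    set_rollout_length = lambda x: int(min(max(w * (x - x0) + b, y0), y1))
    set_buffer_size = lambda x: set_rollout_length(x) * rollout_batch_size * rollout_retain

    assert set_rollout_length(0) == 1         # below the ramp: clamped to the minimum
    assert set_rollout_length(90000) == 13    # on the ramp: 24 / 130000 * 70000 + 1 ~= 13.9, truncated
    assert set_rollout_length(10 ** 6) == 25  # beyond the ramp: clamped to the maximum
    assert set_buffer_size(0) == 200000       # buffer size scales with the current rollout length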
@@ -23,6 +23,7 @@ def eval_ckpt(args):
     eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
     # eval(config, seed=args.seed, state_dict=state_dict)
 
+
 def generate(args):
     main_config.exp_name = 'td3'
     main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
@@ -30,8 +31,14 @@ def generate(args):
     main_config.policy.collect.data_type = 'hdf5'
     config = deepcopy([main_config, create_config])
     state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
-    collect_demo_data(config, collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
-                      seed=args.seed, expert_data_path=main_config.policy.collect.save_path, state_dict=state_dict)
+    collect_demo_data(
+        config,
+        collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+        seed=args.seed,
+        expert_data_path=main_config.policy.collect.save_path,
+        state_dict=state_dict
+    )
 
 
 def train_expert(args):
     from dizoo.mujoco.config.hopper_td3_default_config import main_config, create_config
......
@@ -313,8 +313,8 @@ class MujocoEnv(BaseEnv):
             info.rew_space.shape = rew_shape
             return info
         else:
-            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]'\
-                .format(self._cfg.env_id, MUJOCO_INFO_DICT.keys()))
+            keys = MUJOCO_INFO_DICT.keys()
+            raise NotImplementedError('{} not found in MUJOCO_INFO_DICT [{}]'.format(self._cfg.env_id, keys))
 
     def _make_env(self, only_info=False):
         return wrap_mujoco(
......
@@ -1,5 +1,6 @@
 from typing import Any, Union, List, Callable, Dict
 import copy
 import torch
+import torch.nn as nn
 import numpy as np
 from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
@@ -11,8 +12,10 @@ from .mujoco_wrappers import wrap_mujoco
 from ding.utils import ENV_REGISTRY
 from ding.worker.collector.base_serial_collector import to_tensor_transitions
 
+
+@ENV_REGISTRY.register('mujoco_model')
 class MujocoModelEnv(object):
 
     def __init__(self, env_id: str, set_rollout_length: Callable, rollout_batch_size: int = 100000):
         self.env_id = env_id
         self.rollout_batch_size = rollout_batch_size
@@ -41,9 +44,9 @@ class MujocoModelEnv(object):
             done = ~not_done
             return done
         elif 'walker_' in self.env_id:
-            torso_height = next_obs[:, -2] 
+            torso_height = next_obs[:, -2]
             torso_ang = next_obs[:, -1]
-            if 'walker_7' in env_id or 'walker_5' in env_id:
+            if 'walker_7' in self.env_id or 'walker_5' in self.env_id:
                 offset = 0.
             else:
                 offset = 0.26
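Note that the second change in this hunk is the actual bug fix: the bare ``env_id`` was an undefined name inside the method, so any walker env_id raised a NameError before reaching the termination check. The walker branch reads the torso height and angle from the last two observation dimensions and (in the bounds elided by this hunk) flags a transition as done when they leave an alive region shifted by ``offset``. A generic sketch of that kind of batched termination mask (the thresholds follow the commonly used MBPO walker termination and are illustrative, not taken from this file):

    import torch

    def walker_done(next_obs: torch.Tensor, offset: float = 0.26) -> torch.Tensor:
        torso_height = next_obs[:, -2] - offset  # illustrative use of the offset
        torso_ang = next_obs[:, -1]
        alive = (torso_height > 0.8) & (torso_height < 2.0) \
            & (torso_ang > -1.0) & (torso_ang < 1.0)
        return ~alive  # done mask, shape (B,)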
@@ -57,14 +60,20 @@ class MujocoModelEnv(object):
             done = torch.zeros_like(next_obs.sum(-1)).bool()
             return done
 
-    def rollout(self,
-                env_model: nn.Module,
-                policy: 'Policy',
-                replay_buffer: 'IBuffer',
-                imagine_buffer: 'IBuffer',
-                envstep: int,
-                cur_learner_iter: int) -> None:
-        # This function samples from the replay_buffer, rollouts to generate new data, and push them into the imagine_buffer
+    def rollout(
+            self,
+            env_model: nn.Module,
+            policy: 'Policy',  # noqa
+            replay_buffer: 'IBuffer',  # noqa
+            imagine_buffer: 'IBuffer',  # noqa
+            envstep: int,
+            cur_learner_iter: int
+    ) -> None:
+        """
+        Overview:
+            This function samples from the replay_buffer, rolls out to generate new data,
+            and pushes it into the imagine_buffer.
+        """
         # set rollout length
         rollout_length = self._set_rollout_length(envstep)
         # load data
@@ -82,9 +91,7 @@ class MujocoModelEnv(object):
             timesteps = self.step(obs, actions, env_model)
             obs_new = {}
             for id, timestep in timesteps.items():
-                transition = policy.process_transition(
-                    obs[id], policy_output[id], timestep
-                )
+                transition = policy.process_transition(obs[id], policy_output[id], timestep)
                 transition['collect_iter'] = cur_learner_iter
                 buffer[id].append(transition)
                 if not timestep.done:
@@ -102,9 +109,12 @@ class MujocoModelEnv(object):
     def step(self, obs: Dict, act: Dict, env_model: nn.Module) -> Dict:
         # This function has the same input and output format as env manager's step
         data_id = list(obs.keys())
-        obs = torch.stack([obs[id] for id in data_id],dim=0)
-        act = torch.stack([act[id] for id in data_id],dim=0)
+        obs = torch.stack([obs[id] for id in data_id], dim=0)
+        act = torch.stack([act[id] for id in data_id], dim=0)
         rewards, next_obs = env_model.predict(obs, act)
         terminals = self.termination_fn(next_obs)
-        timesteps = {id:BaseEnvTimestep(n,r,d,{}) for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())}
+        timesteps = {
+            id: BaseEnvTimestep(n, r, d, {})
+            for id, n, r, d in zip(data_id, next_obs.numpy(), rewards.numpy(), terminals.numpy())
+        }
         return timesteps
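Since ``step`` mirrors the env manager's dict-in / dict-out interface, it can be exercised without a trained model. A minimal sketch, run inside this module or with ``MujocoModelEnv`` imported from it (``StubEnvModel``, the tensor shapes, and the env_id are illustrative assumptions; the env_id is assumed to fall through to the final, always-false termination branch above):

    import torch
    import torch.nn as nn

    class StubEnvModel(nn.Module):
        # Stand-in for a learned dynamics model exposing predict(obs, act) -> (reward, next_obs).
        def predict(self, obs: torch.Tensor, act: torch.Tensor):
            reward = torch.zeros(obs.shape[0])             # (B,) rewards
            next_obs = obs + 0.01 * torch.randn_like(obs)  # (B, obs_dim) next states
            return reward, next_obs

    env = MujocoModelEnv('halfcheetah', set_rollout_length=lambda envstep: 1)
    obs = {i: torch.randn(17) for i in range(4)}  # 4 parallel rollouts, 17-dim observations
    act = {i: torch.randn(6) for i in range(4)}   # 6-dim actions
    timesteps = env.step(obs, act, StubEnvModel())
    for env_id, ts in timesteps.items():
        print(env_id, ts.reward, ts.done)         # one BaseEnvTimestep per rollout id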