Unverified commit ffe8d7c0 authored by Weiyuhong-1998, committed by GitHub

feature(wyh): add guided cost algorithm (#57)

* guided_cost

* max_e

* guided_cost

* fix(wyh):fix guided cost recompute bug

* fix(wyh):add model save

* feature(wyh):polish guided cost

* feature(wyh):on guided cost

* fix(wyh):gcl-modify

* fix(wyh):gcl sac config

* fix(wyh):gcl style

* fix(wyh):modify comments

* fix(wyh):masac_5m6m best config

* fix(wyh):sac bug

* fix(wyh):GCL readme

* fix(wyh):GCL readme conflicts
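
A minimal sketch of launching the new pipeline with the LunarLander config added in this PR (the import path of the config module is illustrative; the config file defines `main_config` and `create_config`, and `policy.collect.demonstration_info_path` must first point to a trained expert checkpoint):

```python
# Roughly equivalent to: python3 lunarlander_gcl_config.py
from ding.entry import serial_pipeline_guided_cost
from lunarlander_gcl_config import main_config, create_config  # illustrative import path

serial_pipeline_guided_cost([main_config, create_config], seed=0)
```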
Parent cf8ad134
@@ -123,13 +123,14 @@ ding -m serial -e cartpole -p dqn -s 0
| 23 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 24 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 25 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 26 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 27 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 28 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 29 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 30 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 31 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 32 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 26 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 27 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 28 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 29 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 30 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 31 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 32 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 33 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used for the standard DRL algorithms (1-16)
@@ -11,3 +11,4 @@ from .serial_entry_r2d3 import serial_pipeline_r2d3
from .serial_entry_sqil import serial_pipeline_sqil
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data
from .serial_entry_guided_cost import serial_pipeline_guided_cost
\ No newline at end of file
from typing import Union, Optional, List, Any, Tuple
import copy
import logging
import os
from functools import partial

import torch
from tensorboardX import SummaryWriter

from ding.config import read_config, compile_config
from ding.envs import get_vec_env_setting, create_env_manager
from ding.policy import create_policy, PolicyFactory
from ding.policy.base_policy import Policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed, save_file
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector

def serial_pipeline_guided_cost(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
expert_model: Optional[torch.nn.Module] = None,
max_iterations: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline for guided cost learning (GCL): this pipeline implements guided cost learning \
in DI-engine. Currently the supported environments are CartPole, LunarLander, Hopper, \
HalfCheetah and Walker2d. The demonstration data are generated online by a well-trained \
expert model rather than loaded from a fixed dataset.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module for the expert policy. \
If None, the default model is built from ``cfg.policy.model``.
- max_iterations (:obj:`Optional[int]`): Maximum number of training iterations. The pipeline stops \
when this count is reached.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
expert_collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
expert_collector_env.seed(cfg.seed)
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
expert_policy = create_policy(cfg.policy, model=expert_model, enable_field=['learn', 'collect'])
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
expert_policy.collect_mode.load_state_dict(
torch.load(cfg.policy.collect.demonstration_info_path, map_location='cpu')
)
# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
expert_collector = create_serial_collector(
cfg.policy.collect.collector,
env=expert_collector_env,
policy=expert_policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
expert_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
train_data = copy.deepcopy(new_data)
expert_data = expert_collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
expert_buffer.push(expert_data, cur_collector_envstep=expert_collector.envstep)
# Learn policy from collected data
for i in range(cfg.reward_model.update_per_collect):
expert_demo = expert_buffer.sample(cfg.reward_model.batch_size, learner.train_iter)
samp = replay_buffer.sample(cfg.reward_model.batch_size, learner.train_iter)
reward_model.train(expert_demo, samp, learner.train_iter, collector.envstep)
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration,
# reusing the freshly collected data directly rather than sampling from the replay buffer.
reward_model.estimate(train_data)
if train_data is None:
# It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
logging.warning(
"Replay buffer's data can only train for {} steps. ".format(i) +
"You can modify data collect config, e.g. increasing n_sample, n_episode."
)
break
learner.train(train_data, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
if cfg.policy.on_policy:
# On-policy algorithm must clear the replay buffer.
replay_buffer.clear()
dirname = cfg.exp_name + '/reward_model'
if not os.path.exists(dirname):
try:
os.mkdir(dirname)
except FileExistsError:
pass
if learner.train_iter % cfg.reward_model.store_model_every_n_train == 0:
path = os.path.join(dirname, 'iteration_{}.pth.tar'.format(learner.train_iter))
state_dict = reward_model.state_dict_reward_model()
save_file(path, state_dict)
path = os.path.join(dirname, 'final_model.pth.tar')
state_dict = reward_model.state_dict_reward_model()
save_file(path, state_dict)
# Learner's after_run hook.
learner.call_hook('after_run')
return policy
@@ -159,6 +159,10 @@ class SACPolicy(Policy):
init_w=3e-3,
),
collect=dict(
# Set this key to True if the data collected by the collector should contain the ``logit`` key,
# which reflects the probability of the chosen action.
# Guided Cost Learning needs the logit to train its reward model, so it sets this key to True.
# collector_logit defaults to False.
collector_logit=False,
# You can use either "n_sample" or "n_episode" in actor.collect.
# Get "n_sample" samples per collect.
# Default n_sample to 1.
@@ -481,13 +485,23 @@ class SACPolicy(Policy):
Return:
- transition (:obj:`Dict[str, Any]`): Dict type transition data.
"""
transition = {
'obs': obs,
'next_obs': timestep.obs,
'action': policy_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
if self._cfg.collect.collector_logit:
transition = {
'obs': obs,
'next_obs': timestep.obs,
'logit': policy_output['logit'],
'action': policy_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
else:
transition = {
'obs': obs,
'next_obs': timestep.obs,
'action': policy_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
return transition
def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
@@ -8,4 +8,5 @@ from .red_irl_model import RedRewardModel
from .her_reward_model import HerRewardModel
# exploration
from .rnd_reward_model import RndRewardModel
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from typing import Dict, Any
from easydict import EasyDict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Independent, Normal

from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils.data import default_collate
from .base_reward_model import BaseRewardModel

class GuidedCostNN(nn.Module):
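"""
Overview:
A small MLP cost network that maps a concatenated (observation, action) vector to a scalar cost.
"""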
def __init__(
self,
input_size,
hidden_size=128,
output_size=1,
):
super(GuidedCostNN, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size),
)
def forward(self, x):
return self.net(x)
@REWARD_MODEL_REGISTRY.register('guided_cost')
class GuidedCostRewardModel(BaseRewardModel):
r"""
Overview:
Reward model of the Guided Cost Learning (GCL) algorithm.
https://arxiv.org/pdf/1603.00448.pdf
"""
config = dict(
type='guided_cost',
learning_rate=1e-3,
action_shape=1,
continuous=True,
batch_size=64,
hidden_size=128,
update_per_collect=100,
log_every_n_train=50,
store_model_every_n_train=100,
)
def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
super(GuidedCostRewardModel, self).__init__()
self.cfg = config
self.action_shape = self.cfg.action_shape
assert device == "cpu" or device.startswith("cuda")
self.device = device
self.tb_logger = tb_logger
self.reward_model = GuidedCostNN(config.input_size, config.hidden_size)
self.reward_model.to(self.device)
self.opt = optim.Adam(self.reward_model.parameters(), lr=config.learning_rate)
def train(self, expert_demo: list, samp: list, iter: int, step: int) -> None:
device_0 = expert_demo[0]['obs'].device
device_1 = samp[0]['obs'].device
for i in range(len(expert_demo)):
expert_demo[i]['prob'] = torch.FloatTensor([1]).to(device_0)
if self.cfg.continuous:
for i in range(len(samp)):
(mu, sigma) = samp[i]['logit']
dist = Independent(Normal(mu, sigma), 1)
next_action = samp[i]['action']
log_prob = dist.log_prob(next_action)
samp[i]['prob'] = torch.exp(log_prob).unsqueeze(0).to(device_1)
else:
for i in range(len(samp)):
probs = F.softmax(samp[i]['logit'], dim=-1)
prob = probs[samp[i]['action']]
samp[i]['prob'] = prob.to(device_1)
# Mix the expert data and sample data to train the reward model.
samp.extend(expert_demo)
expert_demo = default_collate(expert_demo)
samp = default_collate(samp)
cost_demo = self.reward_model(
torch.cat([expert_demo['obs'], expert_demo['action'].float().reshape(-1, self.action_shape)], dim=-1)
)
cost_samp = self.reward_model(
torch.cat([samp['obs'], samp['action'].float().reshape(-1, self.action_shape)], dim=-1)
)
prob = samp['prob'].unsqueeze(-1)
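# Guided cost learning (IOC) objective: minimize the mean cost of expert demonstrations plus the log of an
# importance-sampled estimate of the partition function over policy samples, where 1 / prob is the
# importance weight of each sampled transition. Minimizing this pushes demonstration costs down and
# sampled-policy costs up.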
loss_IOC = torch.mean(cost_demo) + \
torch.log(torch.mean(torch.exp(-cost_samp)/(prob+1e-7)))
# Update the cost function with one gradient step on the IOC loss.
self.opt.zero_grad()
loss_IOC.backward()
self.opt.step()
if iter % self.cfg.log_every_n_train == 0:
self.tb_logger.add_scalar('reward_model/loss_iter', loss_IOC, iter)
self.tb_logger.add_scalar('reward_model/loss_step', loss_IOC, step)
def estimate(self, data: list) -> None:
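"""
Overview:
Overwrite the ``reward`` key of each transition with the negated output of the learned cost network, \
so that a lower estimated cost corresponds to a higher reward for the policy.
"""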
for i in range(len(data)):
with torch.no_grad():
reward = self.reward_model(torch.cat([data[i]['obs'],
data[i]['action'].float()]).unsqueeze(0)).squeeze(0)
data[i]['reward'] = -reward
def collect_data(self, data) -> None:
"""
Overview:
Collect training data. Not implemented, because this reward model is trained only on the batches \
passed to ``train``; if it were trained continuously, this method would need an implementation.
"""
pass
def clear_data(self):
"""
Overview:
Clear collected training data. Not implemented, because this reward model keeps no internal data buffer; \
if it were trained continuously, this method would need an implementation.
"""
pass
def state_dict_reward_model(self) -> Dict[str, Any]:
return {
'model': self.reward_model.state_dict(),
'optimizer': self.opt.state_dict(),
}
def load_state_dict_reward_model(self, state_dict: Dict[str, Any]) -> None:
self.reward_model.load_state_dict(state_dict['model'])
self.opt.load_state_dict(state_dict['optimizer'])
from easydict import EasyDict
from ding.entry import serial_pipeline_guided_cost
lunarlander_ppo_config = dict(
exp_name='lunarlander_guided_cost',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
reward_model=dict(
learning_rate=0.001,
input_size=9,
batch_size=32,
continuous=False,
update_per_collect=20,
),
policy=dict(
cuda=False,
continuous=False,
recompute_adv=True,
model=dict(
obs_shape=8,
action_shape=4,
),
learn=dict(
update_per_collect=8,
batch_size=800,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
adv_norm=True,
),
collect=dict(
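# Path of the pre-trained expert policy checkpoint used to generate demonstrations ('path' is a placeholder).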
demonstration_info_path='path',
n_sample=800,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
),
)
lunarlander_ppo_config = EasyDict(lunarlander_ppo_config)
main_config = lunarlander_ppo_config
lunarlander_ppo_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo'),
reward_model=dict(type='guided_cost'),
)
lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
create_config = lunarlander_ppo_create_config
if __name__ == "__main__":
serial_pipeline_guided_cost([main_config, create_config], seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline_guided_cost
cartpole_ppo_offpolicy_config = dict(
exp_name='cartpole_guided_cost',
env=dict(
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
reward_model=dict(
learning_rate=0.001,
input_size=5,
batch_size=32,
continuous=False,
update_per_collect=10,
),
policy=dict(
cuda=False,
continuous=False,
recompute_adv=True,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
),
learn=dict(
update_per_collect=2,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
),
collect=dict(
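# Path of the pre-trained expert policy checkpoint used to generate demonstrations ('path' is a placeholder).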
demonstration_info_path='path',
n_sample=256,
unroll_len=1,
discount_factor=0.9,
gae_lambda=0.95,
),
eval=dict(
evaluator=dict(
eval_freq=50,
cfg_type='InteractionSerialEvaluatorDict',
stop_value=195,
n_episode=5,
),
),
),
)
cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config)
main_config = cartpole_ppo_offpolicy_config
cartpole_ppo_offpolicy_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo'),
reward_model=dict(type='guided_cost'),
)
cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config)
create_config = cartpole_ppo_offpolicy_create_config
if __name__ == "__main__":
serial_pipeline_guided_cost([main_config, create_config], seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline_guided_cost
halfcheetah_gcl_default_config = dict(
env=dict(
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=12000,
),
reward_model=dict(
learning_rate=0.001,
input_size=23,
batch_size=32,
action_shape=6,
continuous=True,
update_per_collect=20,
),
policy=dict(
cuda=False,
on_policy=False,
random_collect_size=0,
model=dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=True,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=False,
),
collect=dict(
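# Path of the pre-trained expert policy checkpoint used to generate demonstrations ('path' is a placeholder).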
demonstration_info_path='path',
collector_logit=True,
n_sample=256,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
halfcheetah_gcl_default_config = EasyDict(halfcheetah_gcl_default_config)
main_config = halfcheetah_gcl_default_config
halfcheetah_gcl_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
reward_model=dict(type='guided_cost'),
)
halfcheetah_gcl_default_create_config = EasyDict(halfcheetah_gcl_default_create_config)
create_config = halfcheetah_gcl_default_create_config
if __name__ == '__main__':
serial_pipeline_guided_cost((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline_guided_cost
from easydict import EasyDict
hopper_gcl_default_config = dict(
exp_name='hopper_guided_cost',
env=dict(
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=4,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
reward_model=dict(
learning_rate=0.001,
input_size=14,
batch_size=32,
action_shape=3,
continuous=True,
update_per_collect=20,
),
policy=dict(
cuda=False,
recompute_adv=True,
model=dict(
obs_shape=11,
action_shape=3,
continuous=True,
),
continuous=True,
learn=dict(
update_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
),
collect=dict(
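# Path of the pre-trained expert policy checkpoint used to generate demonstrations ('path' is a placeholder).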
demonstration_info_path='path',
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=100, )),
),
)
hopper_gcl_default_config = EasyDict(hopper_gcl_default_config)
main_config = hopper_gcl_default_config
hopper_gcl_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='ppo',
import_names=['ding.policy.ppo'],
),
reward_model=dict(type='guided_cost'),
)
hopper_gcl_create_default_config = EasyDict(hopper_gcl_create_default_config)
create_config = hopper_gcl_create_default_config
if __name__ == '__main__':
serial_pipeline_guided_cost((main_config, create_config), seed=0)
from copy import deepcopy
from ding.entry import serial_pipeline_guided_cost
from easydict import EasyDict
walker_gcl_default_config = dict(
env=dict(
env_id='Walker2d-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
reward_model=dict(
learning_rate=0.001,
input_size=23,
batch_size=32,
action_shape=6,
continuous=True,
update_per_collect=20,
),
policy=dict(
cuda=False,
recompute_adv=True,
model=dict(
obs_shape=17,
action_shape=6,
continuous=True,
),
continuous=True,
learn=dict(
update_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
),
collect=dict(
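# Path of the pre-trained expert policy checkpoint used to generate demonstrations ('path' is a placeholder).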
demonstration_info_path='path',
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=100, )),
),
)
walker_gcl_default_config = EasyDict(walker_gcl_default_config)
main_config = walker_gcl_default_config
walker_gcl_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='base'),
policy=dict(
type='ppo',
import_names=['ding.policy.ppo'],
),
replay_buffer=dict(type='naive', ),
reward_model=dict(type='guided_cost'),
)
walker_gcl_create_default_config = EasyDict(walker_gcl_create_default_config)
create_config = walker_gcl_create_default_config
if __name__ == '__main__':
serial_pipeline_guided_cost((main_config, create_config), seed=0)