未验证 提交 5fbc9453 编写于 作者: W Will-Nie 提交者: GitHub

enable user to use any expert model for SQIL (#44)

* enable user to use any model generated here

* delete irrelevant package

* add test

* bash format.sh to reformat style
上级 c8dac674
......@@ -150,7 +150,8 @@ def cli(
from .serial_entry_sqil import serial_pipeline_sqil
if config is None:
config = get_predefined_config(env, policy)
serial_pipeline_sqil(config, seed, max_iterations=train_iter)
expert_config = input("Enter the name of the config you used to generate your expert model: ")
serial_pipeline_sqil(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
......
......@@ -12,11 +12,11 @@ from ding.worker import BaseLearner, SampleCollector, BaseSerialEvaluator, BaseS
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.utils import set_pkg_seed
from ding.model import DQN
def serial_pipeline_sqil(
input_cfg: Union[str, Tuple[dict, dict]],
expert_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
......@@ -47,23 +47,31 @@ def serial_pipeline_sqil(
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
expert_cfg, expert_create_cfg = read_config(expert_cfg)
else:
cfg, create_cfg = input_cfg
expert_cfg, expert_create_cfg = expert_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
expert_cfg = compile_config(
expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
expert_collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
expert_collector_env = create_env_manager(
expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
)
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
expert_collector_env.seed(cfg.seed)
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
expert_policy = create_policy(cfg.policy, model=expert_model, enable_field=['collect'])
expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
expert_policy.collect_mode.load_state_dict(
......@@ -80,20 +88,27 @@ def serial_pipeline_sqil(
exp_name=cfg.exp_name
)
expert_collector = create_serial_collector(
cfg.policy.collect.collector,
expert_cfg.policy.collect.collector,
env=expert_collector_env,
policy=expert_policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
exp_name=expert_cfg.exp_name
)
evaluator = BaseSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
expert_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
expert_buffer = create_buffer(expert_cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
expert_commander = BaseSerialCommander(
expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
expert_policy.command_mode
) # we create this to avoid the issue of eps, this is an issue due to the sample collector part.
expert_collect_kwargs = expert_commander.step()
if 'eps' in expert_collect_kwargs:
expert_collect_kwargs['eps'] = -1
# ==========
# Main loop
# ==========
......@@ -118,7 +133,9 @@ def serial_pipeline_sqil(
break
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
expert_data = expert_collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': -1})
expert_data = expert_collector.collect(
train_iter=learner.train_iter, policy_kwargs=expert_collect_kwargs
) # policy_kwargs={'eps': -1}
for i in range(len(new_data)):
device_1 = new_data[i]['obs'].device
device_2 = expert_data[i]['obs'].device
......
......@@ -313,7 +313,7 @@ def test_sqil():
config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
try:
serial_pipeline_sqil(config, seed=0)
serial_pipeline_sqil(config, [cartpole_sql_config, cartpole_sql_create_config], seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
......
......@@ -18,6 +18,6 @@ def test_sqil():
config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline_sqil(config, seed=0, max_iterations=1)
serial_pipeline_sqil(config, [cartpole_sql_config, cartpole_sql_create_config], seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册