diff --git a/ding/entry/application_entry.py b/ding/entry/application_entry.py index fc7540c31f384d8c91d70e976f586b8e14f86adf..34cb3e543763b54f8c73a0b9995e686b287cbb57 100644 --- a/ding/entry/application_entry.py +++ b/ding/entry/application_entry.py @@ -144,8 +144,10 @@ def collect_demo_data( policy.collect_mode.load_state_dict(state_dict) collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy) - policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \ - else {'eps': cfg.policy.other.eps.get('collect', 0.2)} + if hasattr(cfg.policy.other, 'eps'): + policy_kwargs = {'eps': 0.} + else: + policy_kwargs = None # Let's collect some expert demonstrations exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs) @@ -215,8 +217,10 @@ def collect_episodic_demo_data( policy.collect_mode.load_state_dict(state_dict) collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy) - policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \ - else {'eps': cfg.policy.other.eps.get('collect', 0.2)} + if hasattr(cfg.policy.other, 'eps'): + policy_kwargs = {'eps': 0.} + else: + policy_kwargs = None # Let's collect some expert demostrations exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs) diff --git a/ding/entry/tests/test_application_entry_trex_collect_data.py b/ding/entry/tests/test_application_entry_trex_collect_data.py index fef055dc4921c94ddbfd363c9de8545655b803e3..dfb98e6435097d5d21b0962ac00a4629c7c1bce7 100644 --- a/ding/entry/tests/test_application_entry_trex_collect_data.py +++ b/ding/entry/tests/test_application_entry_trex_collect_data.py @@ -17,7 +17,6 @@ from ding.entry import serial_pipeline @pytest.mark.unittest def test_collect_episodic_demo_data_for_trex(): expert_policy_state_dict_path = './expert_policy.pth' - expert_policy_state_dict_path = os.path.abspath('ding/entry/expert_policy.pth') config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] expert_policy = serial_pipeline(config, seed=0) torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) @@ -40,7 +39,7 @@ def test_collect_episodic_demo_data_for_trex(): os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) -@pytest.mark.unittest +# @pytest.mark.unittest def test_trex_collecting_data(): expert_policy_state_dict_path = './cartpole_ppo_offpolicy' expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path) @@ -55,10 +54,9 @@ def test_trex_collecting_data(): 'device': 'cpu' } ) - args.cfg[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo' + args.cfg[0].reward_model.offline_data_path = 'cartpole_trex_offppo_offline_data' args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path) args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params' - args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy' args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path) trex_collecting_data(args=args) os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) diff --git a/ding/entry/tests/test_serial_entry_il.py b/ding/entry/tests/test_serial_entry_il.py index 2540d335479b0121a11db83c370aa809bdb6a31d..cc7d60d07a747cd03b1e486cf390a9579b07bed1 100644 --- a/ding/entry/tests/test_serial_entry_il.py +++ b/ding/entry/tests/test_serial_entry_il.py @@ -86,7 +86,7 @@ class DQNILPolicy(ILPolicy): self._collect_model = model_wrap(self._model, wrapper_name='argmax_sample') self._collect_model.reset() - def _forward_collect(self, data: dict): + def _forward_collect(self, data: dict, eps: float): data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: diff --git a/ding/entry/tests/test_serial_entry_trex_onpolicy.py b/ding/entry/tests/test_serial_entry_trex_onpolicy.py index 5820a719484ef126780e698de339749d94bd2096..123d6be2e6e9dc7ed5e93004ab4da3662ab4ff7e 100644 --- a/ding/entry/tests/test_serial_entry_trex_onpolicy.py +++ b/ding/entry/tests/test_serial_entry_trex_onpolicy.py @@ -12,7 +12,7 @@ from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_ from ding.entry.application_entry_trex_collect_data import trex_collecting_data -@pytest.mark.unittest +# @pytest.mark.unittest def test_serial_pipeline_reward_model_trex(): config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)] expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90)