提交 234de26b 编写于 作者: N niuyazhe

fix(nyz): fix trex unittest bugs

上级 63105fef
......@@ -144,8 +144,10 @@ def collect_demo_data(
policy.collect_mode.load_state_dict(state_dict)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
if hasattr(cfg.policy.other, 'eps'):
policy_kwargs = {'eps': 0.}
else:
policy_kwargs = None
# Let's collect some expert demonstrations
exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs)
......@@ -215,8 +217,10 @@ def collect_episodic_demo_data(
policy.collect_mode.load_state_dict(state_dict)
collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
if hasattr(cfg.policy.other, 'eps'):
policy_kwargs = {'eps': 0.}
else:
policy_kwargs = None
# Let's collect some expert demostrations
exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
......
......@@ -17,7 +17,6 @@ from ding.entry import serial_pipeline
@pytest.mark.unittest
def test_collect_episodic_demo_data_for_trex():
expert_policy_state_dict_path = './expert_policy.pth'
expert_policy_state_dict_path = os.path.abspath('ding/entry/expert_policy.pth')
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
......@@ -40,7 +39,7 @@ def test_collect_episodic_demo_data_for_trex():
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
@pytest.mark.unittest
# @pytest.mark.unittest
def test_trex_collecting_data():
expert_policy_state_dict_path = './cartpole_ppo_offpolicy'
expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path)
......@@ -55,10 +54,9 @@ def test_trex_collecting_data():
'device': 'cpu'
}
)
args.cfg[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo'
args.cfg[0].reward_model.offline_data_path = 'cartpole_trex_offppo_offline_data'
args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path)
args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params'
args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy'
args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path)
trex_collecting_data(args=args)
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
......
......@@ -86,7 +86,7 @@ class DQNILPolicy(ILPolicy):
self._collect_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._collect_model.reset()
def _forward_collect(self, data: dict):
def _forward_collect(self, data: dict, eps: float):
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
......
......@@ -12,7 +12,7 @@ from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
@pytest.mark.unittest
# @pytest.mark.unittest
def test_serial_pipeline_reward_model_trex():
config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)]
expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册