提交 414b5305 编写于 作者: N niuyazhe

polish(nyz): fix repeat eval at beginning

上级 f962ef01
......@@ -155,6 +155,8 @@ class BaseSerialEvaluator(object):
Determine whether you need to start the evaluation mode, if the number of training has reached\
the maximum number of times to start the evaluator, return True
"""
if train_iter == self._last_eval_iter:
return False
if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
return False
self._last_eval_iter = train_iter
......
......@@ -83,16 +83,17 @@ def main(cfg, seed=0, max_iterations=int(1e8)):
else:
sample_size = learner.policy.get_attribute('batch_size')
train_episode = replay_buffer.sample(sample_size, learner.train_iter)
if train_episode is not None:
train_data = []
if her_cfg is not None:
her_episodes = []
for e in train_episode:
her_episodes.extend(her_model.estimate(e))
train_episode.extend(her_episodes)
if train_episode is None:
break
train_data = []
if her_cfg is not None:
her_episodes = []
for e in train_episode:
train_data.extend(policy.collect_mode.get_train_sample(e))
learner.train(train_data, collector.envstep)
her_episodes.extend(her_model.estimate(e))
train_episode.extend(her_episodes)
for e in train_episode:
train_data.extend(policy.collect_mode.get_train_sample(e))
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
......
......@@ -71,8 +71,9 @@ def main(cfg, seed=0):
# Training
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
......
......@@ -63,11 +63,12 @@ def main(cfg, seed=0):
# Collect data from environments
new_data = collector.collect(train_iter=learner.train_iter)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Trian
# Train
for i in range(cfg.policy.learn.update_per_collect):
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)
if train_data is None:
break
learner.train(train_data, collector.envstep)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册