Commit 414b5305 authored by niuyazhe

polish(nyz): fix repeat eval at beginning

Parent f962ef01
@@ -155,6 +155,8 @@ class BaseSerialEvaluator(object):
         Determine whether you need to start the evaluation mode, if the number of training has reached\
             the maximum number of times to start the evaluator, return True
         """
+        if train_iter == self._last_eval_iter:
+            return False
         if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
             return False
         self._last_eval_iter = train_iter
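This hunk is the actual fix: `should_eval` could return True twice for the same `train_iter` at startup, because the frequency check below is bypassed whenever `train_iter == 0`. A minimal, self-contained sketch of the gating logic (the `EvalGate` class and its `eval_freq` constructor argument are hypothetical stand-ins; only the two `if` checks come from the diff):

```python
class EvalGate:
    """Minimal sketch of the evaluator gating shown in the hunk above."""

    def __init__(self, eval_freq: int) -> None:
        self._eval_freq = eval_freq
        self._last_eval_iter = -1

    def should_eval(self, train_iter: int) -> bool:
        # New guard from this commit: never evaluate twice for the same
        # train_iter (the repeat happened at iter 0, where the frequency
        # check below is skipped entirely).
        if train_iter == self._last_eval_iter:
            return False
        # Original check: wait eval_freq iterations between evaluations,
        # but always allow the very first evaluation at iter 0.
        if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True


gate = EvalGate(eval_freq=100)
assert gate.should_eval(0)       # first eval at startup runs
assert not gate.should_eval(0)   # second call at iter 0 is now blocked
assert not gate.should_eval(50)  # inside the eval_freq window
assert gate.should_eval(100)     # next eval after eval_freq iterations
```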
@@ -83,16 +83,17 @@ def main(cfg, seed=0, max_iterations=int(1e8)):
         else:
             sample_size = learner.policy.get_attribute('batch_size')
         train_episode = replay_buffer.sample(sample_size, learner.train_iter)
-        if train_episode is not None:
-            train_data = []
-            if her_cfg is not None:
-                her_episodes = []
-                for e in train_episode:
-                    her_episodes.extend(her_model.estimate(e))
-                train_episode.extend(her_episodes)
-            for e in train_episode:
-                train_data.extend(policy.collect_mode.get_train_sample(e))
-            learner.train(train_data, collector.envstep)
+        if train_episode is None:
+            break
+        train_data = []
+        if her_cfg is not None:
+            her_episodes = []
+            for e in train_episode:
+                her_episodes.extend(her_model.estimate(e))
+            train_episode.extend(her_episodes)
+        for e in train_episode:
+            train_data.extend(policy.collect_mode.get_train_sample(e))
+        learner.train(train_data, collector.envstep)


 if __name__ == "__main__":
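For context on the HER branch this hunk de-indents: each sampled episode is passed to `her_model.estimate(e)`, which returns relabeled episodes that are appended to the sampled ones, so training covers both the originals and the relabeled copies. A toy illustration of that accumulation pattern (`her_relabel` is a hypothetical stand-in for `her_model.estimate`, using a "final"-goal relabeling strategy; the dict keys are assumptions):

```python
from typing import Any, Dict, List

Episode = List[Dict[str, Any]]


def her_relabel(episode: Episode) -> List[Episode]:
    # Hypothetical 'final'-strategy relabeling: replace every step's
    # desired goal with the goal actually achieved at episode end.
    final_goal = episode[-1]["achieved_goal"]
    return [[dict(step, desired_goal=final_goal) for step in episode]]


train_episode: List[Episode] = [[
    {"obs": 0, "achieved_goal": 1, "desired_goal": 9},
    {"obs": 1, "achieved_goal": 2, "desired_goal": 9},
]]
# Same pattern as the diff: extend the sampled episodes with their
# relabeled counterparts before building train samples.
her_episodes: List[Episode] = []
for e in train_episode:
    her_episodes.extend(her_relabel(e))
train_episode.extend(her_episodes)
assert len(train_episode) == 2  # one original + one relabeled episode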
@@ -71,8 +71,9 @@ def main(cfg, seed=0):
         # Training
         for i in range(cfg.policy.learn.update_per_collect):
             train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data is not None:
-                learner.train(train_data, collector.envstep)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)


 if __name__ == "__main__":
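The same pattern recurs across the entry files: `replay_buffer.sample` returns `None` when it cannot provide a full batch, and within one update round no new data arrives, so later samples are unlikely to succeed either; breaking out of the update loop is therefore cleaner than guarding each train call. A toy illustration of the control-flow difference (`run_updates` is hypothetical):

```python
from typing import List, Optional

Batch = List[int]


def run_updates(batches: List[Optional[Batch]]) -> List[Batch]:
    # New behavior: stop the update round at the first missing batch
    # instead of wrapping the train step in `if batch is not None`.
    trained: List[Batch] = []
    for batch in batches:
        if batch is None:
            break
        trained.append(batch)
    return trained


# With the old `is not None` guard, [5, 6] would still be trained on;
# with `break`, the round ends at the first None.
print(run_updates([[1, 2], [3, 4], None, [5, 6]]))  # [[1, 2], [3, 4]]
```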
@@ -63,11 +63,12 @@ def main(cfg, seed=0):
         # Collect data from environments
         new_data = collector.collect(train_iter=learner.train_iter)
         replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
-        # Trian
+        # Train
         for i in range(cfg.policy.learn.update_per_collect):
             train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data is not None:
-                learner.train(train_data, collector.envstep)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)


 if __name__ == "__main__":
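Taken together, the entry-file loops now read: collect, push, then run up to `update_per_collect` gradient updates, exiting early when the buffer runs dry. Note that `break` only leaves the inner update loop, so collection resumes on the next outer iteration. A self-contained toy of that shape (`ToyBuffer` and the constants are hypothetical; only the control flow mirrors the diff):

```python
import random
from typing import List, Optional


class ToyBuffer:
    # Hypothetical buffer with the same contract as in the diff:
    # sample() returns None when it cannot provide a full batch.
    def __init__(self) -> None:
        self._data: List[int] = []

    def push(self, samples: List[int]) -> None:
        self._data.extend(samples)

    def sample(self, size: int) -> Optional[List[int]]:
        if len(self._data) < size:
            return None
        batch, self._data = self._data[:size], self._data[size:]
        return batch


buffer = ToyBuffer()
update_per_collect, batch_size = 4, 3
for round_id in range(2):
    buffer.push([random.randint(0, 9) for _ in range(5)])  # "collect"
    for i in range(update_per_collect):
        batch = buffer.sample(batch_size)
        if batch is None:
            break  # ends this update round only; collection continues
        print(f"round {round_id} update {i}: train on {batch}")
```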