diff --git a/ding/worker/collector/base_serial_evaluator.py b/ding/worker/collector/base_serial_evaluator.py
index f9d014c7e20bb9de18224baa1d4d76ee83ff681b..0038bd84a62df90a0933223dcfb7be67c180e3e1 100644
--- a/ding/worker/collector/base_serial_evaluator.py
+++ b/ding/worker/collector/base_serial_evaluator.py
@@ -155,6 +155,8 @@ class BaseSerialEvaluator(object):
             Determine whether you need to start the evaluation mode, if the number of training has reached\
                 the maximum number of times to start the evaluator, return True
         """
+        if train_iter == self._last_eval_iter:
+            return False
         if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
             return False
         self._last_eval_iter = train_iter
diff --git a/dizoo/classic_control/bitflip/entry/bitflip_dqn_main.py b/dizoo/classic_control/bitflip/entry/bitflip_dqn_main.py
index fa5162e93166740216ccf16501d37a627772c856..e7935f38a027fafdd62e6986d2db7bea0789241e 100644
--- a/dizoo/classic_control/bitflip/entry/bitflip_dqn_main.py
+++ b/dizoo/classic_control/bitflip/entry/bitflip_dqn_main.py
@@ -83,16 +83,17 @@ def main(cfg, seed=0, max_iterations=int(1e8)):
         else:
             sample_size = learner.policy.get_attribute('batch_size')
             train_episode = replay_buffer.sample(sample_size, learner.train_iter)
-            if train_episode is not None:
-                train_data = []
-                if her_cfg is not None:
-                    her_episodes = []
-                    for e in train_episode:
-                        her_episodes.extend(her_model.estimate(e))
-                    train_episode.extend(her_episodes)
+            if train_episode is None:
+                break
+            train_data = []
+            if her_cfg is not None:
+                her_episodes = []
                 for e in train_episode:
-                    train_data.extend(policy.collect_mode.get_train_sample(e))
-                learner.train(train_data, collector.envstep)
+                    her_episodes.extend(her_model.estimate(e))
+                train_episode.extend(her_episodes)
+            for e in train_episode:
+                train_data.extend(policy.collect_mode.get_train_sample(e))
+            learner.train(train_data, collector.envstep)
 
 
 if __name__ == "__main__":
diff --git a/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py b/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py
index 6f2af9074fda3c2768d264e184751316fb259e89..2a32a466506b24a7d0ff979f54b529ef060e13d5 100644
--- a/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py
+++ b/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py
@@ -71,8 +71,9 @@ def main(cfg, seed=0):
         # Training
         for i in range(cfg.policy.learn.update_per_collect):
             train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data is not None:
-                learner.train(train_data, collector.envstep)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)
 
 
 if __name__ == "__main__":
diff --git a/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py b/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py
index 9dd937c1e1e49493caa88654ae178c1a4d44c5ff..5a9fa1814cf092bb572906d990a0cc5f9902b059 100644
--- a/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py
+++ b/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py
@@ -63,11 +63,12 @@ def main(cfg, seed=0):
         # Collect data from environments
         new_data = collector.collect(train_iter=learner.train_iter)
         replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
-        # Trian
+        # Train
         for i in range(cfg.policy.learn.update_per_collect):
             train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
-            if train_data is not None:
-                learner.train(train_data, collector.envstep)
+            if train_data is None:
+                break
+            learner.train(train_data, collector.envstep)
 
 
 if __name__ == "__main__":
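
For context, the patch applies two recurring fixes. In the `dizoo` entry scripts, `if train_data is not None: learner.train(...)` becomes an early `break`: once `replay_buffer.sample(...)` returns `None` (the buffer does not yet hold a full batch), no later iteration of the `update_per_collect` loop can succeed either, since the buffer is not refilled inside that loop. In `should_eval`, a new guard prevents the evaluator from running twice at the same `train_iter`. Below is a minimal, self-contained sketch of both patterns; the `Buffer` and `Evaluator` classes are illustrative stand-ins, not the actual DI-engine classes.

```python
from typing import List, Optional


class Buffer:
    """Illustrative stand-in for a replay buffer."""

    def __init__(self, data: List[int], batch_size: int) -> None:
        self._data = data
        self._batch_size = batch_size

    def sample(self) -> Optional[List[int]]:
        # Mirror replay_buffer.sample(...): return None when there is
        # not enough data for a full batch.
        if len(self._data) < self._batch_size:
            return None
        return self._data[:self._batch_size]


class Evaluator:
    """Illustrative stand-in showing the should_eval guard."""

    def __init__(self, eval_freq: int) -> None:
        self._eval_freq = eval_freq
        self._last_eval_iter = -1

    def should_eval(self, train_iter: int) -> bool:
        # New guard from the patch: never evaluate twice at the same iter.
        if train_iter == self._last_eval_iter:
            return False
        if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True


def train_loop(buffer: Buffer, update_per_collect: int) -> int:
    trained = 0
    for _ in range(update_per_collect):
        train_data = buffer.sample()
        if train_data is None:
            # The buffer cannot refill inside this loop, so later
            # iterations would fail too; break instead of skipping.
            break
        trained += 1  # stands in for learner.train(train_data, envstep)
    return trained


if __name__ == "__main__":
    assert train_loop(Buffer(list(range(4)), batch_size=8), update_per_collect=5) == 0
    assert train_loop(Buffer(list(range(16)), batch_size=8), update_per_collect=5) == 5
    ev = Evaluator(eval_freq=10)
    assert ev.should_eval(0) is True    # first call at iter 0 evaluates
    assert ev.should_eval(0) is False   # repeat call at same iter is now skipped
    assert ev.should_eval(5) is False   # still inside the eval_freq window
    assert ev.should_eval(12) is True   # eval_freq reached
```

The second `should_eval(0)` assertion is exactly the case the evaluator change fixes: before the guard, `train_iter == 0` bypassed the frequency check, so repeated calls at iteration 0 re-ran evaluation.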