diff --git a/PaddleCV/PaddleDetection/tools/train.py b/PaddleCV/PaddleDetection/tools/train.py index c688fef4ec91687918b7399c72298df4ec5a5df6..f45462cc90cca903f089adaf21332a67320a6dc4 100644 --- a/PaddleCV/PaddleDetection/tools/train.py +++ b/PaddleCV/PaddleDetection/tools/train.py @@ -264,7 +264,7 @@ def main(): profiler.start_profiler("All") elif FLAGS.is_profiler and it == 10: profiler.stop_profiler("total", FLAGS.profiler_path) - return + return if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ and (not FLAGS.dist or trainer_id == 0):