From 1691dc7a9c0a3e861a16e58c5508e3e7233be27d Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Fri, 24 Sep 2021 16:43:53 +0800
Subject: [PATCH] add update (#36017)

---
 .../paddle/distributed/fleet/meta_parallel/pipeline_parallel.py | 1 +
 python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 8fad0686dd..431bc6d7bc 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -329,6 +329,7 @@ class PipelineParallel(MetaParallelBase):
     def _optimizer_step(self):
         if self.scaler:
             self.scaler.step(self.optimizer)
+            self.scaler.update()
         else:
             self.optimizer.step()
 
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
index 083ad31930..4c966585d5 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
@@ -48,6 +48,7 @@ class TestMPClipGrad(TestDistMPTraning):
         scaled.backward()  # do backward
 
         scaler.step(optimizer)  # update parameters
+        scaler.update()
         optimizer.clear_grad()
         return scaled
 
-- 
GitLab
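
Note (not part of the patch): a minimal sketch of the single-card dynamic-graph AMP loop that this change completes, assuming the public paddle.amp.GradScaler API. The model, batch data, and hyperparameters below are illustrative only. GradScaler.step() unscales gradients and applies the optimizer update; GradScaler.update() then refreshes the loss-scaling factor for the next iteration, which is the call the patch adds after each step().

# Illustrative sketch of the scale -> backward -> step -> update cycle.
import paddle

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for _ in range(3):
    batch = paddle.randn([4, 16])
    with paddle.amp.auto_cast():   # run forward in mixed precision
        loss = model(batch).mean()

    scaled = scaler.scale(loss)    # scale the loss
    scaled.backward()              # do backward on the scaled loss

    scaler.step(optimizer)         # unscale gradients, update parameters
    scaler.update()                # adjust loss scaling for the next step
    optimizer.clear_grad()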