From 4a9de931a2dc2b4696cf02e56f12d70218ec892e Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 21 Jan 2021 10:58:59 +0800 Subject: [PATCH] Fix the bug in fleet amp_init. (#30606) * Fix the bug in fleet amp_init. * Fix the amp_init unit test. --- python/paddle/distributed/fleet/base/fleet_base.py | 4 ++-- python/paddle/fluid/tests/unittests/test_fleet_amp_init.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 3a631edb921..0e4559e6bc6 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1019,8 +1019,8 @@ class Fleet(object): run_example_code() """ # imitate target optimizer retrieval - return self.user_defined_optimizer.amp_init( - place, scope=None, test_program=None, use_fp16_test=False) + return self.user_defined_optimizer.amp_init(place, scope, test_program, + use_fp16_test) def _final_strategy(self): if "valid_strategy" not in self._context: diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py index d7da4ead1b0..2fa6bf54769 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py @@ -67,7 +67,7 @@ class TestFleetAMPInit(unittest.TestCase): exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - optimizer.amp_init(place, use_fp16_test=True) + optimizer.amp_init(place) step = 1 for i in range(step): -- GitLab