diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 3a631edb92128eab9f050f518b01ecdd825f4209..0e4559e6bc624b99e1d9f72b629e2e9d9a499bd0 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1019,8 +1019,8 @@ class Fleet(object): run_example_code() """ # imitate target optimizer retrieval - return self.user_defined_optimizer.amp_init( - place, scope=None, test_program=None, use_fp16_test=False) + return self.user_defined_optimizer.amp_init(place, scope, test_program, + use_fp16_test) def _final_strategy(self): if "valid_strategy" not in self._context: diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py index d7da4ead1b0ed47b81409efc161c55ce835525ee..2fa6bf54769e0fa0bd8b0c97a40bd523c623bee6 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py @@ -67,7 +67,7 @@ class TestFleetAMPInit(unittest.TestCase): exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) - optimizer.amp_init(place, use_fp16_test=True) + optimizer.amp_init(place) step = 1 for i in range(step):