未验证 提交 4a9de931 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the bug in fleet amp_init. (#30606)

* Fix the bug in fleet amp_init.

* Fix the amp_init unit test.
上级 7e9f336b
...@@ -1019,8 +1019,8 @@ class Fleet(object): ...@@ -1019,8 +1019,8 @@ class Fleet(object):
run_example_code() run_example_code()
""" """
# imitate target optimizer retrieval # imitate target optimizer retrieval
return self.user_defined_optimizer.amp_init( return self.user_defined_optimizer.amp_init(place, scope, test_program,
place, scope=None, test_program=None, use_fp16_test=False) use_fp16_test)
def _final_strategy(self): def _final_strategy(self):
if "valid_strategy" not in self._context: if "valid_strategy" not in self._context:
......
...@@ -67,7 +67,7 @@ class TestFleetAMPInit(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestFleetAMPInit(unittest.TestCase):
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
optimizer.amp_init(place, use_fp16_test=True) optimizer.amp_init(place)
step = 1 step = 1
for i in range(step): for i in range(step):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册