提交 14d06fb6 编写于 作者: G gaotingquan 提交者: Tingquan Gao

support AMP.use_amp arg

上级 b0877289
...@@ -525,7 +525,8 @@ class Engine(object): ...@@ -525,7 +525,8 @@ class Engine(object):
def _init_amp(self): def _init_amp(self):
amp_config = self.config.get("AMP", None) amp_config = self.config.get("AMP", None)
use_amp = True if amp_config else False use_amp = True if amp_config and amp_config.get("use_amp",
True) else False
if not use_amp: if not use_amp:
self.auto_cast = AutoCast(use_amp) self.auto_cast = AutoCast(use_amp)
......
...@@ -21,7 +21,7 @@ class AutoCast: ...@@ -21,7 +21,7 @@ class AutoCast:
paddle.amp.auto_cast, paddle.amp.auto_cast,
level=amp_level, level=amp_level,
use_promote=use_promote) use_promote=use_promote)
# paddle version < 2.3.0 and not develop # paddle version <= 2.4.x and not develop
else: else:
self.cast_context = partial( self.cast_context = partial(
paddle.amp.auto_cast, level=amp_level) paddle.amp.auto_cast, level=amp_level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册