提交 b0877289 编写于 作者: G gaotingquan 提交者: Tingquan Gao

disable promote kernel for amp training

compatible with paddle 2.5 and older version.
ref: https://github.com/PaddlePaddle/PaddleClas/pull/2798
上级 162f013e
...@@ -13,7 +13,18 @@ class AutoCast: ...@@ -13,7 +13,18 @@ class AutoCast:
self.amp_eval = amp_eval self.amp_eval = amp_eval
if self.use_amp: if self.use_amp:
self.cast_context = partial(paddle.amp.auto_cast, level=amp_level) # compatible with paddle 2.5 and older version
paddle_version = paddle.__version__[:3]
# paddle version >= 2.5.0 or develop
if paddle_version in ["2.5", "0.0"]:
self.cast_context = partial(
paddle.amp.auto_cast,
level=amp_level,
use_promote=use_promote)
# paddle version < 2.3.0 and not develop
else:
self.cast_context = partial(
paddle.amp.auto_cast, level=amp_level)
def __call__(self, is_eval=False): def __call__(self, is_eval=False):
if self.use_amp: if self.use_amp:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册