From 282e09dcfd604f356fd3c8a63eae7d66c58dc015 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sun, 27 Feb 2022 11:59:27 +0800 Subject: [PATCH] fix pylayer problem with amp (#39950) * fix pylayer problem with amp * add ut * refine code --- python/paddle/autograd/py_layer.py | 10 +++++++ python/paddle/fluid/dygraph/amp/auto_cast.py | 13 +++++++++ .../test_imperative_auto_mixed_precision.py | 27 +++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py index 5a22d22151..26740dfd0f 100644 --- a/python/paddle/autograd/py_layer.py +++ b/python/paddle/autograd/py_layer.py @@ -14,6 +14,8 @@ import paddle from paddle.fluid.framework import dygraph_only +from paddle.fluid.dygraph.amp.auto_cast import amp_state +from paddle.amp.auto_cast import auto_cast from paddle.fluid import core __all__ = [] @@ -46,6 +48,7 @@ class PyLayerContext(object): def __init__(self): self.container = None + self._amp_state = amp_state() def save_for_backward(self, *tensors): """ @@ -178,6 +181,13 @@ class PyLayerBackward(PyLayerContext): def backward(self, *args, **kwargs): with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.no_grad(): + if self._amp_state and 'enable' in self._amp_state and self._amp_state[ + 'enable']: + with auto_cast(**args[0]._amp_state): + return self._forward_cls.backward(*args, **kwargs) + else: + + return self._forward_cls.backward(*args, **kwargs) return self._forward_cls.backward(*args, **kwargs) diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 41a7d3d774..8230e4bbd7 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -78,6 +78,13 @@ PURE_FP16_BLACK_LIST = { BF16_WHITE_LIST = {'conv2d'} BF16_BLACK_LIST = {' '} +_g_amp_state_ = None + + +def amp_state(): + global _g_amp_state_ + return _g_amp_state_ + #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. @@ -240,6 +247,11 @@ def amp_guard(enable=True, print(conv.dtype) # FP32 """ + amp_state = locals() + global _g_amp_state_ + original_state = _g_amp_state_ + _g_amp_state_ = amp_state + # check amp_level: O0-O2 level = level.upper() if not (level in ['O0', 'O1', 'O2']): @@ -349,6 +361,7 @@ def amp_guard(enable=True, yield finally: if tracer: + _g_amp_state_ = original_state tracer._amp_level = original_amp_level tracer._set_amp_op_list(original_white_list, original_black_list) # set_flags(original_flags) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 0043a7f78b..67c4bb3b2c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -20,6 +20,7 @@ import six from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting import paddle.nn as nn from paddle.static import InputSpec +from paddle.autograd import PyLayer if fluid.core.is_compiled_with_cuda(): fluid.set_flags({"FLAGS_cudnn_deterministic": True}) @@ -1146,5 +1147,31 @@ class TestBf16(unittest.TestCase): self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1)) +class TestPyLayerWithAmp(unittest.TestCase): + def test_pylayer(self): + class MyMM(PyLayer): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + return a.mm(b) + + @staticmethod + def backward(ctx, grad): + a, b = ctx.saved_tensor() + # NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast() + # thus, the mm operation raise errors because of the dtype of inputs are inconsistent + return grad.mm(b.t()), a.t().mm(grad) + + x = paddle.rand([10, 10]) + y = paddle.rand([10, 10]) + x.stop_gradient = False + y.stop_gradient = False + + with paddle.amp.auto_cast(): + res = MyMM.apply(x, y) + loss = paddle.mean(res) + loss.backward() + + if __name__ == '__main__': unittest.main() -- GitLab