未验证 提交 282e09dc 编写于 作者: L Leo Chen 提交者: GitHub

fix pylayer problem with amp (#39950)

* fix pylayer problem with amp

* add ut

* refine code
上级 b33a3c23
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import paddle import paddle
from paddle.fluid.framework import dygraph_only from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph.amp.auto_cast import amp_state
from paddle.amp.auto_cast import auto_cast
from paddle.fluid import core from paddle.fluid import core
__all__ = [] __all__ = []
...@@ -46,6 +48,7 @@ class PyLayerContext(object): ...@@ -46,6 +48,7 @@ class PyLayerContext(object):
def __init__(self): def __init__(self):
self.container = None self.container = None
self._amp_state = amp_state()
def save_for_backward(self, *tensors): def save_for_backward(self, *tensors):
""" """
...@@ -178,6 +181,13 @@ class PyLayerBackward(PyLayerContext): ...@@ -178,6 +181,13 @@ class PyLayerBackward(PyLayerContext):
def backward(self, *args, **kwargs): def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
if self._amp_state and 'enable' in self._amp_state and self._amp_state[
'enable']:
with auto_cast(**args[0]._amp_state):
return self._forward_cls.backward(*args, **kwargs)
else:
return self._forward_cls.backward(*args, **kwargs)
return self._forward_cls.backward(*args, **kwargs) return self._forward_cls.backward(*args, **kwargs)
......
...@@ -78,6 +78,13 @@ PURE_FP16_BLACK_LIST = { ...@@ -78,6 +78,13 @@ PURE_FP16_BLACK_LIST = {
BF16_WHITE_LIST = {'conv2d'} BF16_WHITE_LIST = {'conv2d'}
BF16_BLACK_LIST = {' '} BF16_BLACK_LIST = {' '}
_g_amp_state_ = None
def amp_state():
global _g_amp_state_
return _g_amp_state_
#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
...@@ -240,6 +247,11 @@ def amp_guard(enable=True, ...@@ -240,6 +247,11 @@ def amp_guard(enable=True,
print(conv.dtype) # FP32 print(conv.dtype) # FP32
""" """
amp_state = locals()
global _g_amp_state_
original_state = _g_amp_state_
_g_amp_state_ = amp_state
# check amp_level: O0-O2 # check amp_level: O0-O2
level = level.upper() level = level.upper()
if not (level in ['O0', 'O1', 'O2']): if not (level in ['O0', 'O1', 'O2']):
...@@ -349,6 +361,7 @@ def amp_guard(enable=True, ...@@ -349,6 +361,7 @@ def amp_guard(enable=True,
yield yield
finally: finally:
if tracer: if tracer:
_g_amp_state_ = original_state
tracer._amp_level = original_amp_level tracer._amp_level = original_amp_level
tracer._set_amp_op_list(original_white_list, original_black_list) tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags) # set_flags(original_flags)
......
...@@ -20,6 +20,7 @@ import six ...@@ -20,6 +20,7 @@ import six
from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
import paddle.nn as nn import paddle.nn as nn
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.autograd import PyLayer
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True}) fluid.set_flags({"FLAGS_cudnn_deterministic": True})
...@@ -1146,5 +1147,31 @@ class TestBf16(unittest.TestCase): ...@@ -1146,5 +1147,31 @@ class TestBf16(unittest.TestCase):
self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1)) self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1))
class TestPyLayerWithAmp(unittest.TestCase):
def test_pylayer(self):
class MyMM(PyLayer):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
def backward(ctx, grad):
a, b = ctx.saved_tensor()
# NOTE(zhiqiu): a and b is float32 now, while grad is fp16 when forward runs with auto_cast()
# thus, the mm operation raise errors because of the dtype of inputs are inconsistent
return grad.mm(b.t()), a.t().mm(grad)
x = paddle.rand([10, 10])
y = paddle.rand([10, 10])
x.stop_gradient = False
y.stop_gradient = False
with paddle.amp.auto_cast():
res = MyMM.apply(x, y)
loss = paddle.mean(res)
loss.backward()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册