Unverified commit f650e901, authored by qizhaoaoe, committed by GitHub

fix dtype cast in amp for instance_norm. (#52765)

* fix dtype cast in amp.

* add test case and update docs.

* remove set_prim.
Parent 2309aa58
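
The fix keeps the InstanceNorm family in float32 under AMP O2, matching how BatchNorm and LayerNorm are already handled. Below is a minimal dygraph sketch of the pattern the fix targets, condensed from the test added in this commit; the layer shapes, the learning rate, and the assumption of a GPU place are illustrative rather than part of the change.

```python
import paddle
from paddle import nn

# A small net containing InstanceNorm2D, trained one step under AMP O2.
net = nn.Sequential(
    nn.Conv2D(2, 4, (3, 3), bias_attr=False),
    nn.InstanceNorm2D(4),
)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=net.parameters())

# With level='O2', parameters are cast to float16 except the norm layers,
# so the InstanceNorm2D scale/bias stay float32.
net = paddle.amp.decorate(models=net, level='O2')

x = paddle.randn([4, 2, 6, 6], dtype="float32")
with paddle.amp.auto_cast(enable=True, level='O2'):
    out = net(x)
    loss = paddle.mean(out)
    loss.backward()
    sgd.step()
    sgd.clear_grad()
```
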
@@ -213,6 +213,9 @@ def pure_fp16_initialize(models):
                     paddle.nn.BatchNorm3D,
                     paddle.nn.LayerNorm,
                     paddle.nn.SyncBatchNorm,
+                    paddle.nn.InstanceNorm1D,
+                    paddle.nn.InstanceNorm2D,
+                    paddle.nn.InstanceNorm3D,
                 ),
             ):
                 continue
@@ -522,7 +525,7 @@ def amp_decorate(
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
-    When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm.
+    When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm, InstanceNorm and LayerNorm.
     Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.
@@ -530,7 +533,7 @@ def amp_decorate(
         models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
         optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
         level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
-            O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
+            O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm, InstanceNorm and LayerNorm. Default is O1(amp)
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
         save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
@@ -741,7 +744,7 @@ def decorate(
 ):
     """
     Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
-    When level is O2(pure float16/bfloat16), the decorate will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm.
+    When level is O2(pure float16/bfloat16), the decorate will cast all parameters of models to float16/bfloat16, except BatchNorm, InstanceNorm and LayerNorm.
     Commonly, it is used together with `auto_cast` to achieve Pure float16/bfloat16 in imperative mode.
@@ -749,7 +752,7 @@ def decorate(
         models(Layer|list of Layer): The defined models by user, models must be either a single model or a list of models. Default is None.
         optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
         level(str, optional): Auto mixed precision level. Accepted values are 'O1' and 'O2': O1 represent mixed precision, the decorator will do nothing;
-            O2 represent Pure float16/bfloat16, the decorator will cast all parameters of models to float16/bfloat16, except BatchNorm and LayerNorm. Default is O1(amp)
+            O2 represent Pure float16/bfloat16, the decorator will cast all parameters of models to float16/bfloat16, except BatchNorm, InstanceNorm and LayerNorm. Default is O1(amp)
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
         master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
         save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
...
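
The docstring edits above and the `pure_fp16_initialize` change describe the same behavior: with `level='O2'`, `paddle.amp.decorate` casts model parameters to float16/bfloat16 but leaves BatchNorm, InstanceNorm and LayerNorm parameters in float32. A small hedged check of that behavior follows; the Sequential model and layer sizes are made up for illustration, and the printed dtypes assume a CUDA place.

```python
import paddle
from paddle import nn

# Illustrative model: one layer that should be cast and one that should not.
model = nn.Sequential(nn.Linear(8, 8), nn.InstanceNorm1D(8))
model = paddle.amp.decorate(models=model, level='O2', dtype='float16')

# Expected after decoration: the Linear weight/bias become float16, while
# the InstanceNorm1D scale/bias stay float32 (InstanceNorm is now in the
# skip list of pure_fp16_initialize).
for name, param in model.named_parameters():
    print(name, param.dtype)
```
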
@@ -18,8 +18,9 @@ import numpy as np
 from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
-from paddle import fluid
-from paddle.fluid import Program, core, program_guard
+import paddle.nn.functional as F
+from paddle import fluid, nn
+from paddle.fluid import Program, core, framework, program_guard
 
 
 class TestInstanceNorm(unittest.TestCase):
@@ -319,5 +320,64 @@ class TestInstanceNormBF16OP(OpTest):
         )
 
 
+class PrimNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
+        self.instance_norm = nn.InstanceNorm2D(4)
+
+    def forward(self, x):
+        y = self.conv(x)
+        out = self.instance_norm(y)
+        res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
+        return res
+
+
+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=False)
+
+
+class TestPrimForwardAndBackward(unittest.TestCase):
+    """
+    Test PrimNet with @to_static + amp O2(with fp32)
+    """
+
+    def setUp(self):
+        paddle.seed(2022)
+        paddle.disable_static()
+        self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
+        self.x.stop_gradient = False
+
+    def train(self, use_amp, data_layout="NCHW"):
+        paddle.seed(2022)
+        net = PrimNet()
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=net.parameters()
+        )
+        net = apply_to_static(net, False)
+        if use_amp:
+            net = paddle.amp.decorate(models=net, level='O2')
+        with paddle.amp.auto_cast(enable=use_amp, level='O2'):
+            out = net(self.x)
+            loss = paddle.mean(out)
+            loss.backward()
+            sgd.step()
+            sgd.clear_grad()
+        return loss
+
+    def test_amp_nchw(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(False)
+            actual = self.train(True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -99,6 +99,8 @@ def _keep_fp32_input(op, in_name):
         return in_name != 'X'
     if op_type == 'layer_norm' and _keep_layer_norm_scale_bias_to_fp32():
         return in_name != 'X'
+    if op_type == 'instance_norm':
+        return in_name != 'X'
     if op_type == 'fused_bn_add_activation':
         return in_name not in {'X', 'Z'}
     if op_type == 'resnet_unit':
...
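
For static-graph AMP, the added branch mirrors the batch_norm and layer_norm handling: only the `X` input of an `instance_norm` op may be cast to low precision, while inputs such as `Scale` and `Bias` are kept in float32. The sketch below restates that rule in isolation; the helper name `keeps_fp32` is illustrative, and only the `instance_norm` branch is taken from this diff.

```python
def keeps_fp32(op_type: str, in_name: str) -> bool:
    """Return True if the given op input must stay float32 under AMP.

    Only the instance_norm branch comes from this commit: every input of
    the op except 'X' (i.e. Scale and Bias) is kept in float32, which is
    what the added lines in _keep_fp32_input do.
    """
    if op_type == 'instance_norm':
        return in_name != 'X'
    return False


# Tiny check of the rule:
assert keeps_fp32('instance_norm', 'Scale') is True
assert keeps_fp32('instance_norm', 'Bias') is True
assert keeps_fp32('instance_norm', 'X') is False
```
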