diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 979be562db93a519d93275bfde43d08a8b7c6bc1..a9d91b6e854898a5e037078aa6165beaea37f541 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1699,6 +1699,9 @@ void batch_norm_grad(const Tensor& x,
       if (use_global_stats) {
         auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
         auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
+        }
         set_output<T>(nchw_x_grad, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
@@ -1732,6 +1735,9 @@ void batch_norm_grad(const Tensor& x,
           sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
       if (use_global_stats) {
         auto x_grad_data = scale * rsqrt_var * out_grad_data;
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          x_grad_data = cast<T>(x_grad_data, x.dtype());
+        }
         set_output<T>(x_grad_data, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
diff --git a/test/prim/composite_ops/test_composite_batch_norm.py b/test/prim/composite_ops/test_composite_batch_norm.py
index e61269c8f22eeaf0053cf3c01c90c8331c75d778..497d8e4d0a505687dd186976d407efb0077bce0c 100644
--- a/test/prim/composite_ops/test_composite_batch_norm.py
+++ b/test/prim/composite_ops/test_composite_batch_norm.py
@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn):
 
 
 class PrimeNet(paddle.nn.Layer):
-    def __init__(self, data_layout='NCHW'):
+    def __init__(self, data_layout='NCHW', is_test=False):
         super().__init__()
         self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
-        self.bn = BatchNorm(4, act="relu", data_layout=data_layout)
+        self.bn = BatchNorm(
+            4, act="relu", data_layout=data_layout, is_test=is_test
+        )
 
     def forward(self, x):
         y = self.conv(x)
@@ -408,10 +410,10 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
         self.x.stop_gradient = False
 
-    def train(self, use_prim, data_layout="NCHW"):
+    def train(self, use_prim, data_layout="NCHW", is_test=False):
         core._set_prim_all_enabled(use_prim)
         paddle.seed(2022)
-        net = PrimeNet(data_layout)
+        net = PrimeNet(data_layout=data_layout, is_test=is_test)
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
@@ -429,8 +431,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
 
     def test_amp_nchw(self):
         if not isinstance(framework._current_expected_place(), core.CPUPlace):
-            expected = self.train(False)
-            actual = self.train(True)
+            expected = self.train(use_prim=False)
+            actual = self.train(use_prim=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+    def test_amp_nchw_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(use_prim=False, is_test=True)
+            actual = self.train(use_prim=True, is_test=True)
             np.testing.assert_allclose(
                 expected,
                 actual,
@@ -449,6 +462,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
                 atol=1e-3,
             )
 
+    def test_amp_nhwc_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(
+                use_prim=False, data_layout="NHWC", is_test=True
+            )
+            actual = self.train(use_prim=True, data_layout="NHWC", is_test=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
 
 class TestPrimEvalBranch(unittest.TestCase):
     """
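The guard added in composite_backward_api.h casts the computed gradient back to the input dtype before set_output. Below is a minimal Python sketch of that rule, not part of the patch; the tensor names and shapes are illustrative and only mirror the C++ variables. The idea as I read the change: under AMP, scale and rsqrt(variance) are kept in float32, so in the use_global_stats branch x_grad comes out as float32 and has to be cast back to float16 to match x.

```python
import paddle

# Illustrative shapes only; names mirror the variables in the composite rule.
x = paddle.randn([4, 2, 6, 6]).astype("float16")       # AMP input, float16
out_grad = paddle.ones([4, 2, 6, 6], dtype="float32")  # upstream grad computed in float32
scale = paddle.ones([1, 2, 1, 1], dtype="float32")     # BN scale stays float32
rsqrt_var = paddle.rsqrt(paddle.full([1, 2, 1, 1], 1.0) + 1e-5)

# use_global_stats branch: x_grad = scale * rsqrt(var) * out_grad, all float32 ...
x_grad = scale * rsqrt_var * out_grad
assert x_grad.dtype == paddle.float32

# ... so it must be cast back to x.dtype, as the added guard does for FLOAT16.
if x.dtype == paddle.float16:
    x_grad = paddle.cast(x_grad, x.dtype)
assert x_grad.dtype == paddle.float16
```

The new test_amp_nchw_eval and test_amp_nhwc_eval cases exercise exactly this path: is_test=True makes BatchNorm use its running statistics, which drives the use_global_stats branch of the composite backward, and the prim result is compared against the non-prim result under AMP for both layouts.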