Unverified commit e0b58212, authored by cyber-pioneer, committed by GitHub

[Prim] fix batch_norm custom_vjp dtype (#51843)

* fix batch_norm custom_vjp dtype

* add nhwc test example

* fix typo
Parent: f293e36c
@@ -1282,7 +1282,7 @@ void batch_norm_grad(const Tensor& x,
           auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
           auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
-          auto part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2;
+          auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
           auto x_grad_data = part1 * part2;
           if (x.dtype() == phi::DataType::FLOAT16) {
...
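For context on the one-line change above: in the FP16 path, batch_norm_grad casts x and out_grad to FP32 up front (x_data, out_grad_data), computes every intermediate in FP32, and casts x_grad_data back at the end. part2 previously mixed the uncast out_grad into that FP32 expression, which broke the dtype invariant. A minimal NumPy sketch of the same cast-compute-cast-back pattern (illustrative only, not the actual Paddle kernel; variable names mirror the diff):

```python
import numpy as np

def batch_norm_x_grad(x, out_grad, mean, var, scale, eps=1e-5):
    """Illustrative NumPy version of the composite batch_norm x-gradient.

    Mirrors the pattern in batch_norm_grad above: for float16 inputs the
    whole computation runs in float32 and only the result is cast back.
    Assumes NCHW layout, reducing over N, H, W.
    """
    orig_dtype = x.dtype
    # Cast once up front; out_grad_data is the float32 copy the fix uses.
    x_data = x.astype(np.float32)
    out_grad_data = out_grad.astype(np.float32)

    axes = (0, 2, 3)
    nhw = x_data.size / x_data.shape[1]
    rsqrt_var = (1.0 / np.sqrt(var + eps)).reshape(1, -1, 1, 1)
    x_sub_mean = x_data - mean.reshape(1, -1, 1, 1)

    part1 = scale.reshape(1, -1, 1, 1) * rsqrt_var
    mean_temp1 = out_grad_data.sum(axes, keepdims=True) / nhw
    tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw
    mean_temp2 = tmp.sum(axes, keepdims=True)
    # The bug: the uncast (possibly float16) out_grad leaked in here.
    part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2
    return (part1 * part2).astype(orig_dtype)
```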
@@ -349,14 +349,10 @@ def apply_to_static(net, use_cinn):
 class PrimeNet(paddle.nn.Layer):
-    def __init__(self):
+    def __init__(self, data_layout='NCHW'):
         super().__init__()
-        self.conv = nn.Conv2D(4, 2, (3, 3), bias_attr=False)
-        self.bn = BatchNorm(2, act="relu")
-        self.run_mean = zeros([2])
-        self.run_var = ones([2])
-        self.scale = ones([2])
-        self.bias = ones([2])
+        self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
+        self.bn = BatchNorm(4, act="relu", data_layout=data_layout)

     def forward(self, x):
         y = self.conv(x)
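The updated test net threads data_layout through to BatchNorm and fixes the channel wiring: with input [4, 2, 6, 6], nn.Conv2D(2, 4, (3, 3)) produces a [4, 4, 4, 4] activation, which has size 4 on every axis, so BatchNorm(4) sees a valid 4-channel tensor under both the NCHW and NHWC interpretations. A quick shape check (a sketch using public Paddle APIs):

```python
import paddle
import paddle.nn as nn

x = paddle.randn([4, 2, 6, 6], dtype="float32")
conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
y = conv(x)
print(y.shape)  # [4, 4, 4, 4] -- square in every axis, so the NHWC
                # BatchNorm below still finds 4 channels on the last axis
bn = nn.BatchNorm(4, act="relu", data_layout="NHWC")
out = bn(y)
```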
@@ -372,13 +368,13 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
-        self.x = paddle.randn([4, 4, 6, 6], dtype="float32")
+        self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
         self.x.stop_gradient = False

-    def train(self, use_prim):
+    def train(self, use_prim, data_layout="NCHW"):
         core._set_prim_all_enabled(use_prim)
         paddle.seed(2022)
-        net = PrimeNet()
+        net = PrimeNet(data_layout)
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
@@ -394,7 +390,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         sgd.clear_grad()
         return loss

-    def test_amp(self):
+    def test_amp_nchw(self):
         if not isinstance(framework._current_expected_place(), core.CPUPlace):
             expected = self.train(False)
             actual = self.train(True)
@@ -405,6 +401,17 @@ class TestPrimForwardAndBackward(unittest.TestCase):
                 atol=1e-3,
             )

+    def test_amp_nhwc(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(use_prim=False, data_layout="NHWC")
+            actual = self.train(use_prim=True, data_layout="NHWC")
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+

 if __name__ == '__main__':
     unittest.main()
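The new test_amp_nhwc repeats the NCHW comparison with the layout flag flipped: one run with composite (prim) rules disabled, one with them enabled, and the losses must agree within 1e-3. A condensed, self-contained version of that flow is sketched below; TinyNet is a hypothetical stand-in for the PrimeNet above, and the sketch assumes a Paddle build with prim support running on a non-CPU place:

```python
import numpy as np
import paddle
import paddle.nn as nn
from paddle.fluid import core


class TinyNet(paddle.nn.Layer):
    """Minimal stand-in for PrimeNet: conv followed by layout-aware BN."""

    def __init__(self, data_layout="NCHW"):
        super().__init__()
        self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
        self.bn = nn.BatchNorm(4, act="relu", data_layout=data_layout)

    def forward(self, x):
        return self.bn(self.conv(x))


def run(use_prim, data_layout="NCHW"):
    core._set_prim_all_enabled(use_prim)  # toggle composite (prim) rules
    paddle.seed(2022)                     # identical init and data each run
    net = paddle.jit.to_static(TinyNet(data_layout))
    x = paddle.randn([4, 2, 6, 6], dtype="float32")
    x.stop_gradient = False
    loss = net(x).mean()
    loss.backward()
    return loss.numpy()


# The composite batch_norm should match the fused kernel at AMP tolerance.
np.testing.assert_allclose(
    run(False, "NHWC"), run(True, "NHWC"), rtol=1e-3, atol=1e-3
)
```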