未验证提交 e0b58212 · 作者: cyber-pioneer · 提交者: GitHub

[Prim] fix batch_norm custom_vjp dtype (#51843)

* fix batch_norm custom_vjp dtype

* add nhwc test example

* fix typo
上级 f293e36c
......@@ -1282,7 +1282,7 @@ void batch_norm_grad(const Tensor& x,
auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
- auto part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2;
+ auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) {
......
......@@ -349,14 +349,10 @@ def apply_to_static(net, use_cinn):
class PrimeNet(paddle.nn.Layer):
- def __init__(self):
+ def __init__(self, data_layout='NCHW'):
super().__init__()
- self.conv = nn.Conv2D(4, 2, (3, 3), bias_attr=False)
- self.bn = BatchNorm(2, act="relu")
- self.run_mean = zeros([2])
- self.run_var = ones([2])
- self.scale = ones([2])
- self.bias = ones([2])
+ self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
+ self.bn = BatchNorm(4, act="relu", data_layout=data_layout)
def forward(self, x):
y = self.conv(x)
......@@ -372,13 +368,13 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
- self.x = paddle.randn([4, 4, 6, 6], dtype="float32")
+ self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False
- def train(self, use_prim):
+ def train(self, use_prim, data_layout="NCHW"):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
- net = PrimeNet()
+ net = PrimeNet(data_layout)
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
......@@ -394,7 +390,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
sgd.clear_grad()
return loss
- def test_amp(self):
+ def test_amp_nchw(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(False)
actual = self.train(True)
......@@ -405,6 +401,17 @@ class TestPrimForwardAndBackward(unittest.TestCase):
atol=1e-3,
)
+ def test_amp_nhwc(self):
+ if not isinstance(framework._current_expected_place(), core.CPUPlace):
+ expected = self.train(use_prim=False, data_layout="NHWC")
+ actual = self.train(use_prim=True, data_layout="NHWC")
+ np.testing.assert_allclose(
+ expected,
+ actual,
+ rtol=1e-3,
+ atol=1e-3,
+ )
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册