Unverified commit eb12e627, authored by cyber-pioneer, committed by GitHub

fix eval branch of prim vjp of batch_norm in amp mode (#53598)

Parent aec4e38f
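For context, the bug lives in the eval (use_global_stats) branch of batch_norm's composite (prim) VJP: under AMP the gradient is computed from float32 scale and variance, so for a float16 input x_grad could come back as float32. A minimal standalone sketch of the scenario, assuming a GPU place and a Paddle build with prim lowering enabled; this is not the repository test, and the exact failure signature may differ by build:

import paddle
import paddle.nn as nn
from paddle.framework import core

core._set_prim_all_enabled(True)  # lower batch_norm_grad to its prim VJP

net = nn.Sequential(
    nn.Conv2D(2, 4, (3, 3), bias_attr=False),
    nn.BatchNorm(4, act="relu", is_test=True),  # eval: use_global_stats branch
)
net = paddle.jit.to_static(net)  # prim lowering applies to the static graph

x = paddle.randn([4, 2, 6, 6], dtype="float32")
x.stop_gradient = False
with paddle.amp.auto_cast():  # O1 AMP: conv/bn compute in float16
    loss = net(x).sum()
loss.backward()  # before this fix, x.grad dtype could mismatch x under AMP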
@@ -1699,6 +1699,9 @@ void batch_norm_grad(const Tensor& x,
       if (use_global_stats) {
         auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
         auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
+        }
         set_output<T>(nchw_x_grad, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
@@ -1732,6 +1735,9 @@ void batch_norm_grad(const Tensor& x,
           sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
       if (use_global_stats) {
         auto x_grad_data = scale * rsqrt_var * out_grad_data;
+        if (x.dtype() == phi::DataType::FLOAT16) {
+          x_grad_data = cast<T>(x_grad_data, x.dtype());
+        }
         set_output<T>(x_grad_data, x_grad);
       } else {
         auto part1 = scale * rsqrt_var;
......
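Both hunks add the same guard: the composite gradient is produced in float32 (scale and rsqrt_var stay float32 under AMP), so for a float16 input it must be cast back to the input dtype before set_output. A hypothetical Python analogue of that guard, with the helper name invented for illustration:

import paddle

def _cast_grad_to_input_dtype(grad, x):
    # Mirrors the C++ guard above: the prim VJP math runs in float32 under
    # AMP, so a float16 input needs its gradient cast back before output.
    if x.dtype == paddle.float16 and grad.dtype != x.dtype:
        return grad.astype(x.dtype)
    return grad

x = paddle.ones([2, 3], dtype="float16")
g = paddle.ones([2, 3], dtype="float32")  # stand-in for scale * rsqrt_var * out_grad
assert _cast_grad_to_input_dtype(g, x).dtype == paddle.float16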
@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn):
 class PrimeNet(paddle.nn.Layer):
-    def __init__(self, data_layout='NCHW'):
+    def __init__(self, data_layout='NCHW', is_test=False):
         super().__init__()
         self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
-        self.bn = BatchNorm(4, act="relu", data_layout=data_layout)
+        self.bn = BatchNorm(
+            4, act="relu", data_layout=data_layout, is_test=is_test
+        )

     def forward(self, x):
         y = self.conv(x)
@@ -408,10 +410,10 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
         self.x.stop_gradient = False

-    def train(self, use_prim, data_layout="NCHW"):
+    def train(self, use_prim, data_layout="NCHW", is_test=False):
         core._set_prim_all_enabled(use_prim)
         paddle.seed(2022)
-        net = PrimeNet(data_layout)
+        net = PrimeNet(data_layout=data_layout, is_test=is_test)
         sgd = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=net.parameters()
         )
@@ -429,8 +431,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def test_amp_nchw(self):
         if not isinstance(framework._current_expected_place(), core.CPUPlace):
-            expected = self.train(False)
-            actual = self.train(True)
+            expected = self.train(use_prim=False)
+            actual = self.train(use_prim=True)
             np.testing.assert_allclose(
                 expected,
                 actual,
                 rtol=1e-3,
                 atol=1e-3,
             )
+
+    def test_amp_nchw_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(use_prim=False, is_test=True)
+            actual = self.train(use_prim=True, is_test=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
@@ -449,6 +462,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
                 atol=1e-3,
             )
+
+    def test_amp_nhwc_eval(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(
+                use_prim=False, data_layout="NHWC", is_test=True
+            )
+            actual = self.train(use_prim=True, data_layout="NHWC", is_test=True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )

 class TestPrimEvalBranch(unittest.TestCase):
     """
......
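The new is_test=True cases matter because they route BatchNorm onto the use_global_stats branch patched above: in eval mode the layer normalizes with the tracked running mean and variance instead of per-batch statistics. A standalone dygraph sketch of that behavioral difference, using BatchNorm2D as a stand-in and no prim lowering involved:

import paddle
import paddle.nn as nn

paddle.seed(2022)
x = paddle.randn([4, 4, 6, 6])

bn = nn.BatchNorm2D(4)
y_train = bn(x)  # training mode: normalizes with batch statistics

bn.eval()        # eval mode: normalizes with running mean/variance,
y_eval = bn(x)   # i.e. the use_global_stats branch of batch_norm_grad

print(float((y_train - y_eval).abs().max()))  # nonzero: the branches differ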