“e983cc90fcee4e5b73bce9d4853b85aac4661e3a”上不存在“...paddle/v2/fluid/tests/unittests/test_protobuf_descs.py”
未验证 提交 95a7bcf9 编写于 作者: C cyber-pioneer 提交者: GitHub

fix eval branch of prim vjp of batch_norm in amp mode (#53594)

上级 2cf4a04a
...@@ -1428,6 +1428,9 @@ void batch_norm_grad(const Tensor& x, ...@@ -1428,6 +1428,9 @@ void batch_norm_grad(const Tensor& x,
if (use_global_stats) { if (use_global_stats) {
auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad; auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim); auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
if (x.dtype() == phi::DataType::FLOAT16) {
nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
}
set_output<T>(nchw_x_grad, x_grad); set_output<T>(nchw_x_grad, x_grad);
} else { } else {
auto part1 = scale * rsqrt_var; auto part1 = scale * rsqrt_var;
...@@ -1461,6 +1464,9 @@ void batch_norm_grad(const Tensor& x, ...@@ -1461,6 +1464,9 @@ void batch_norm_grad(const Tensor& x,
sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false); sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
if (use_global_stats) { if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data; auto x_grad_data = scale * rsqrt_var * out_grad_data;
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_data = cast<T>(x_grad_data, x.dtype());
}
set_output<T>(x_grad_data, x_grad); set_output<T>(x_grad_data, x_grad);
} else { } else {
auto part1 = scale * rsqrt_var; auto part1 = scale * rsqrt_var;
......
...@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn): ...@@ -386,10 +386,12 @@ def apply_to_static(net, use_cinn):
class PrimeNet(paddle.nn.Layer): class PrimeNet(paddle.nn.Layer):
def __init__(self, data_layout='NCHW'): def __init__(self, data_layout='NCHW', is_test=False):
super().__init__() super().__init__()
self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False) self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
self.bn = BatchNorm(4, act="relu", data_layout=data_layout) self.bn = BatchNorm(
4, act="relu", data_layout=data_layout, is_test=is_test
)
def forward(self, x): def forward(self, x):
y = self.conv(x) y = self.conv(x)
...@@ -408,10 +410,10 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -408,10 +410,10 @@ class TestPrimForwardAndBackward(unittest.TestCase):
self.x = paddle.randn([4, 2, 6, 6], dtype="float32") self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False self.x.stop_gradient = False
def train(self, use_prim, data_layout="NCHW"): def train(self, use_prim, data_layout="NCHW", is_test=False):
core._set_prim_all_enabled(use_prim) core._set_prim_all_enabled(use_prim)
paddle.seed(2022) paddle.seed(2022)
net = PrimeNet(data_layout) net = PrimeNet(data_layout=data_layout, is_test=is_test)
sgd = paddle.optimizer.SGD( sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters() learning_rate=0.1, parameters=net.parameters()
) )
...@@ -429,8 +431,19 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -429,8 +431,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def test_amp_nchw(self): def test_amp_nchw(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace): if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(False) expected = self.train(use_prim=False)
actual = self.train(True) actual = self.train(use_prim=True)
np.testing.assert_allclose(
expected,
actual,
rtol=1e-3,
atol=1e-3,
)
def test_amp_nchw_eval(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(use_prim=False, is_test=True)
actual = self.train(use_prim=True, is_test=True)
np.testing.assert_allclose( np.testing.assert_allclose(
expected, expected,
actual, actual,
...@@ -449,6 +462,19 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -449,6 +462,19 @@ class TestPrimForwardAndBackward(unittest.TestCase):
atol=1e-3, atol=1e-3,
) )
def test_amp_nhwc_eval(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(
use_prim=False, data_layout="NHWC", is_test=True
)
actual = self.train(use_prim=True, data_layout="NHWC", is_test=True)
np.testing.assert_allclose(
expected,
actual,
rtol=1e-3,
atol=1e-3,
)
class TestPrimEvalBranch(unittest.TestCase): class TestPrimEvalBranch(unittest.TestCase):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册