未验证 提交 20befdef 编写于 作者: C cyber-pioneer 提交者: GitHub

fix eval branch of composite rule of batch_norm (#52154)

上级 bcea3b89
...@@ -413,5 +413,35 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -413,5 +413,35 @@ class TestPrimForwardAndBackward(unittest.TestCase):
) )
class TestPrimEvalBranch(unittest.TestCase):
"""
Test eval branch or composite rule of batch_norm.
"""
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([4, 2, 6, 6], dtype="float32")
self.x.stop_gradient = False
def train(self, use_prim):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
net = BatchNorm(2, is_test=True)
net = apply_to_static(net, False)
out = net(self.x)
loss = paddle.mean(out)
return loss
def test_eval_branch(self):
expected = self.train(False)
actual = self.train(True)
np.testing.assert_allclose(
expected,
actual,
rtol=1e-6,
atol=1e-6,
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -86,11 +86,6 @@ def composite_batchnorm( ...@@ -86,11 +86,6 @@ def composite_batchnorm(
feature_axis = ( feature_axis = (
1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1 1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
) )
if use_global_stats is None:
use_global_stats = is_test
trainable_statistics = False
else:
trainable_statistics = not use_global_stats
use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats
reduce_axes = tuple(i for i in range(len(x.shape)) if i != feature_axis) reduce_axes = tuple(i for i in range(len(x.shape)) if i != feature_axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册