提交 ad91bfe5 编写于 作者: T Tao Luo 提交者: qingqing01

fix a bug in test_batch_norm_op.py (#10094)

上级 eb8e14c9
...@@ -100,6 +100,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format): ...@@ -100,6 +100,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
# transfer from (N, C, H, W) to (N, H, W, C) to simplify computation # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
if data_format != "NCHW" and data_format != "NHWC":
raise ValueError("Unknown data order.")
if data_format == "NCHW": if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1)) x = np.transpose(x, (0, 2, 3, 1))
y_grad = np.transpose(y_grad, (0, 2, 3, 1)) y_grad = np.transpose(y_grad, (0, 2, 3, 1))
...@@ -304,7 +307,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -304,7 +307,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
# run backward # run backward
y_grad = np.random.random_sample(shape).astype(np.float32) y_grad = np.random.random_sample(shape).astype(np.float32)
x_grad, scale_grad, bias_grad = _reference_grad( x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, var_ref, epsilon, data_format) x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
var_dict = locals() var_dict = locals()
var_dict['y@GRAD'] = y_grad var_dict['y@GRAD'] = y_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册