提交 f456a4e9 编写于 作者: Z zchen0211

batch-norm forward backward nchw, nhwc passed

上级 03789a7d
......@@ -184,8 +184,8 @@ class TestBatchNormOp(OpTest):
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self):
def test_with_place(place, tensor_format):
# attr
data_format = "NCHW"
epsilon = 0.00001
momentum = 0.9
......@@ -222,9 +222,9 @@ class TestBatchNormOp(OpTest):
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
data_format)
def test_with_place(place, tensor_format=data_format):
scope = core.Scope()
# create input
......@@ -275,14 +275,13 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace):
atol = 5e-2
else:
atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol)
print "op test forward passed: ", tensor_format
print "op test forward passed: ", str(place), tensor_format
# run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
......@@ -307,14 +306,14 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", tensor_format
print "op test backward passed: ", str(place), tensor_format
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0))
for place in places:
test_with_place(place)
print "test forward passed"
for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册