提交 0a95a44b 编写于 作者: K Kexin Zhao

add python batch norm inference test

上级 39c676e2
...@@ -125,8 +125,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -125,8 +125,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor; math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, saved_mean, 0); functor(dev_ctx, saved_mean, static_cast<T>(0));
functor(dev_ctx, saved_variance, 0); functor(dev_ctx, saved_variance, static_cast<T>(0));
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
......
...@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op return backward_op
def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
if data_format == "NCHW":
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
if data_format == "NCHW":
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (1, c, 1, 1))
mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
elif data_format == "NHWC":
normalized = (x - mean) / np.sqrt(var + epsilon)
y = normalized * scale + offset
else:
raise ValueError("Unknown data order.")
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape x_shape = x.shape
if len(x_shape) == 2: if len(x_shape) == 2:
...@@ -155,7 +186,43 @@ def set_output_grad(scope, outputs, place, feed_dict=None): ...@@ -155,7 +186,43 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
__set_tensor__(output, data) __set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOpInference(OpTest):
def setUp(self):
self.dtype = np.float32
def test_python(self):
data_format = "NHWC"
epsilon = 0.00001
n, h, w, c = 2, 3, 4, 5
x_shape = [n, h, w, c]
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(self.dtype)
scale_val = np.random.random_sample(scale_shape).astype(self.dtype)
bias_val = np.random.random_sample(scale_shape).astype(self.dtype)
mean = np.zeros(scale_shape).astype(self.dtype)
variance = np.ones(scale_shape).astype(self.dtype)
# run forward
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, "NHWC")
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
epsilon, "NCHW")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "inference output")
print 'python: NHWC, NCHW, inference checking passed'
class TestBatchNormOpTraining(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册