未验证 提交 eaa41857 编写于 作者: Z Zhuoyuan 提交者: GitHub

Merge pull request #5103 from zchen0211/batch-norm-latest

Batch norm latest
...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
math::SetConstant<platform::GPUPlace, T> functor; math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0); functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0); functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T> ...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif #endif
std::vector<int> dims = {N, C, H, W, D}; std::vector<int> dims;
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C}; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
......
...@@ -21,8 +21,26 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -21,8 +21,26 @@ def get_backward_op(scope, op, no_grad_set):
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
if data_format != "NHWC": if data_format == "NCHW":
raise ValueError("data_format must be NHWC, got %s." % data_format) n, c, h, w = x.shape
x_square = x * x
x_square_sum = np.sum(x_square, (0, 2, 3))
x_sum = np.sum(x, axis=(0, 2, 3))
element_count = np.size(x) / int(np.shape(x)[1])
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
mean_tile = np.reshape(mean, (1, c, 1, 1))
mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
return y, mean, var
elif data_format == "NHWC":
x_square = x * x x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2)) x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2)) x_sum = np.sum(x, axis=(0, 1, 2))
...@@ -31,6 +49,8 @@ def _reference_training(x, scale, offset, epsilon, data_format): ...@@ -31,6 +49,8 @@ def _reference_training(x, scale, offset, epsilon, data_format):
var = x_square_sum / element_count - mean * mean var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon) normalized = (x - mean) / np.sqrt(var + epsilon)
return (normalized * scale + offset), mean, var return (normalized * scale + offset), mean, var
else:
raise ValueError("Unknown data order.")
def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
# grad_x = # grad_x =
# 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) - # 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format) # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
grad_y = np.transpose(grad_y, (0, 2, 3, 1))
# raise ValueError("data_format must be NHWC, got %s." % data_format)
grad_x = scale * (grad_y - np.mean( grad_x = scale * (grad_y - np.mean(
grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean( grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
grad_y * (x - mean), axis=(0, 1, 2)) / grad_y * (x - mean), axis=(0, 1, 2)) /
...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon), grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
axis=(0, 1, 2)) axis=(0, 1, 2))
grad_offset = np.sum(grad_y, axis=(0, 1, 2)) grad_offset = np.sum(grad_y, axis=(0, 1, 2))
# transfer back to N, C, H, W
if data_format == "NCHW":
grad_x = np.transpose(grad_x, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
grad_y = np.transpose(grad_y, (0, 3, 1, 2))
return grad_x, grad_scale, grad_offset return grad_x, grad_scale, grad_offset
...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place): ...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place):
return tensor return tensor
def set_output_grad(scope, outputs, place): def set_output_grad(scope, outputs, place, feed_dict=None):
def __set_tensor__(name): def __set_tensor__(name, data=None):
out_tensor = scope.find_var(name).get_tensor() out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64: if out_dtype == core.DataType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64) data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32: elif out_dtype == core.DataType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32) data = np.ones(out_tensor.shape(), dtype=np.float32)
else: else:
raise ValueError("Not supported data type " + str(out_dtype)) raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place) grad_tensor.set(data, place)
for output in outputs: for output in outputs:
__set_tensor__(output) data = None
if output in feed_dict:
data = feed_dict[output]
__set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_python(self):
data_format = "NHWC"
epsilon = 0.00001
momentum = 0.9
# N, H, W, C: 2, 3, 4, 2
n, h, w, c = 2, 3, 4, 2
x_shape = [n, h, w, c]
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# run forward
y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, "NHWC")
#
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2, saved_mean2, var_ref2 = _reference_training(
x_val2, scale_val, bias_val, epsilon, "NCHW")
self.__assert_close(saved_mean, saved_mean2, "batch mean")
self.__assert_close(var_ref, var_ref2, "batch variance")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "batch variance")
print 'python: NHWC, NCHW, forward checking passed'
# test backward now
# NHWC
self.y_grad = np.random.random_sample(x_shape).astype(np.float32)
y_grad = self.y_grad
# y_grad = np.ones(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")
# NCHW
y_grad2 = np.transpose(y_grad, (0, 3, 1, 2))
# y_grad2 = np.ones(x_shape2).astype(np.float32)
x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")
self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient")
self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient")
x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1))
self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient")
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self): def test_forward_backward(self):
def test_with_place(place, tensor_format):
# attr # attr
data_format = "NHWC"
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
channel_num = 2 # N, H, W, C: 12, 3, 4, 2
x_shape = [2, 3, 4, channel_num] n, h, w, c = 2, 3, 4, 2
scale_shape = [channel_num]
if data_format == "NHWC":
x_shape = [n, h, w, c]
elif data_format == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data type.")
scale_shape = [c]
# input
x_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32) bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32) mean = np.zeros(scale_shape).astype(np.float32)
variance = np.zeros(scale_shape).astype(np.float32) variance = np.ones(scale_shape).astype(np.float32)
# run forward # run forward
y_out, saved_mean, var_ref = _reference_training( y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format) x_val, scale_val, bias_val, epsilon, data_format)
# run backward # update moving mean and variance
mean_out = saved_mean * (1 - momentum) mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1 - momentum) variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1 / np.sqrt(var_ref + epsilon) saved_variance = 1. / np.sqrt(var_ref + epsilon)
# for gradient test # for gradient test
y_grad = np.ones(x_shape).astype(np.float32) # y_grad = np.ones(x_shape).astype(np.float32)
y_grad = np.zeros(x_shape).astype(np.float32)
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad( x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format) x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
data_format)
def test_with_place(place):
scope = core.Scope() scope = core.Scope()
# create input # create input
...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest): ...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest):
SavedVariance="saved_variance", SavedVariance="saved_variance",
# attrs # attrs
is_test=False, is_test=False,
tensor_format=data_format, tensor_format=tensor_format,
momentum=momentum, momentum=momentum,
epsilon=epsilon) epsilon=epsilon)
...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest): ...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance, self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance") "saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out") self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace): if isinstance(place, core.GPUPlace):
atol = 5e-2 atol = 5e-2
else: else:
atol = 1e-4 atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out, self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol) "variance_out", atol)
print "op test forward passed: ", str(place), tensor_format
# run backward # run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set()) batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
set_output_grad( set_output_grad(
scope, scope,
["y_out", "mean", "variance", "saved_mean", "saved_variance"], ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place) place,
feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx) batch_norm_op_grad.run(scope, ctx)
x_grad_tensor = create_or_get_tensor(scope, x_grad_tensor = create_or_get_tensor(scope,
...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest): ...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", str(place), tensor_format
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0)) places.append(core.GPUPlace(0))
for place in places: for place in places:
test_with_place(place) for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册