未验证 提交 c975fe1b 编写于 作者: Q Qiao Longfei 提交者: GitHub

batch norm support matrix input (#5980)

* batch norm support matrix input

* update gpu code

* format code
上级 23b3fef0
......@@ -62,13 +62,14 @@ class BatchNormOp : public framework::OperatorWithKernel {
const auto x_dims = ctx->GetInputDim("X");
const TensorFormat tensor_format =
StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format"));
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"Input X must have 2 to 5 dimensions.");
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"Input X must have 3 to 5 dimensions.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
......@@ -146,8 +147,8 @@ class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
const int N = x_dims[0];
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
......@@ -339,8 +340,8 @@ class BatchNormGradKernel<platform::CPUPlace, T>
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
const int N = x_dims[0];
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
......
......@@ -29,14 +29,21 @@ void ExtractNCWHD(const framework::DDim &dims,
const TensorFormat &tensor_format, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
: 1;
if (dims.size() == 2) {
*C = dims[1];
*H = 1;
*W = 1;
*D = 1;
} else {
*C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
*H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
: 1;
}
}
template <typename T>
......@@ -56,8 +63,8 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
// NCHW [batch_size, in_channels, in_height, in_width]
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
......@@ -180,8 +187,8 @@ class BatchNormGradKernel<platform::GPUPlace, T>
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
......
......@@ -69,8 +69,7 @@ def vgg16_bn_drop(input):
drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = fluid.layers.fc(input=drop, size=512, act=None)
reshape1 = fluid.layers.reshape(x=fc1, shape=list(fc1.shape + (1, 1)))
bn = fluid.layers.batch_norm(input=reshape1, act='relu')
bn = fluid.layers.batch_norm(input=fc1, act='relu')
drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
return fc2
......
......@@ -21,6 +21,13 @@ def get_backward_op(scope, op, no_grad_set):
def _reference_training(x, scale, offset, epsilon, data_format):
x_shape = x.shape
if len(x_shape) == 2:
if data_format == "NCHW":
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
if data_format == "NCHW":
n, c, h, w = x.shape
x_square = x * x
......@@ -39,6 +46,8 @@ def _reference_training(x, scale, offset, epsilon, data_format):
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
if len(x_shape) == 2:
y = np.reshape(y, (y.shape[0], y.shape[1]))
return y, mean, var
elif data_format == "NHWC":
x_square = x * x
......@@ -48,7 +57,10 @@ def _reference_training(x, scale, offset, epsilon, data_format):
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
return (normalized * scale + offset), mean, var
y = normalized * scale + offset
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y, mean, var
else:
raise ValueError("Unknown data order.")
......@@ -65,6 +77,18 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
# transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
x_shape = x.shape
if len(x_shape) == 2:
if data_format == "NCHW":
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
grad_y = np.reshape(grad_y,
(grad_y.shape[0], grad_y.shape[1], 1, 1))
else:
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
grad_y = np.reshape(grad_y,
(grad_y.shape[0], 1, 1, grad_y.shape[1]))
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
grad_y = np.transpose(grad_y, (0, 2, 3, 1))
......@@ -83,6 +107,9 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
grad_x = np.transpose(grad_x, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
grad_y = np.transpose(grad_y, (0, 3, 1, 2))
if len(x_shape) == 2:
grad_x = np.reshape(grad_x, x_shape)
return grad_x, grad_scale, grad_offset
......@@ -127,7 +154,7 @@ class TestBatchNormOp(OpTest):
momentum = 0.9
# N, H, W, C: 2, 3, 4, 2
n, h, w, c = 2, 3, 4, 2
n, h, w, c = 2, 3, 4, 5
x_shape = [n, h, w, c]
scale_shape = [c]
......@@ -184,20 +211,23 @@ class TestBatchNormOp(OpTest):
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self):
def test_with_place(place, tensor_format):
def test_with_place(place, tensor_format, shape):
# attr
epsilon = 0.00001
momentum = 0.9
# N, H, W, C: 12, 3, 4, 2
n, h, w, c = 2, 3, 4, 2
if data_format == "NHWC":
x_shape = [n, h, w, c]
elif data_format == "NCHW":
x_shape = [n, c, h, w]
if len(shape) == 2:
x_shape = shape
c = shape[1]
else:
raise ValueError("Unknown data type.")
# n, h, w, c = 2, 3, 4, 2
n, h, w, c = shape[0], shape[1], shape[2], shape[3]
if data_format == "NHWC":
x_shape = [n, h, w, c]
elif data_format == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data type.")
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
......@@ -219,7 +249,10 @@ class TestBatchNormOp(OpTest):
# for gradient test
# y_grad = np.ones(x_shape).astype(np.float32)
y_grad = np.zeros(x_shape).astype(np.float32)
y_grad[0, 0, 0, 0] = 1.
if len(y_grad.shape) == 2:
y_grad[0, 0] = 1.
else:
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
......@@ -313,7 +346,8 @@ class TestBatchNormOp(OpTest):
places.append(core.GPUPlace(0))
for place in places:
for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
test_with_place(place, data_format, [2, 3, 4, 5])
test_with_place(place, data_format, [2, 3])
if __name__ == '__main__':
......
import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid as fluid
import paddle.v2.fluid.nets as nets
from paddle.v2.fluid.framework import Program
......@@ -29,27 +29,35 @@ class TestLayer(unittest.TestCase):
def test_batch_norm_layer(self):
main_program = Program()
startup_program = Program()
images = layers.data(
images = fluid.layers.data(
name='pixel',
shape=[3, 48, 48],
dtype='float32',
main_program=main_program)
layers.batch_norm(
hidden1 = fluid.layers.batch_norm(
input=images,
main_program=main_program,
startup_program=startup_program)
hidden2 = fluid.layers.fc(input=hidden1,
size=128,
act='relu',
main_program=main_program)
hidden3 = fluid.layers.batch_norm(
input=hidden2,
main_program=main_program,
startup_program=startup_program)
# print str(main_program)
print str(main_program)
def test_dropout_layer(self):
main_program = Program()
startup_program = Program()
images = layers.data(
images = fluid.layers.data(
name='pixel',
shape=[3, 48, 48],
dtype='float32',
main_program=main_program)
layers.dropout(
fluid.layers.dropout(
x=images,
dropout_prob=0.5,
main_program=main_program,
......@@ -61,7 +69,7 @@ class TestLayer(unittest.TestCase):
main_program = Program()
startup_program = Program()
images = layers.data(
images = fluid.layers.data(
name='pixel',
shape=[3, 48, 48],
dtype='float32',
......@@ -77,19 +85,19 @@ class TestLayer(unittest.TestCase):
def test_elementwise_add_with_act(self):
main_program = Program()
startup_program = Program()
image1 = layers.data(
image1 = fluid.layers.data(
name='pixel1',
shape=[3, 48, 48],
dtype='float32',
main_program=main_program,
startup_program=startup_program)
image2 = layers.data(
image2 = fluid.layers.data(
name='pixel2',
shape=[3, 48, 48],
dtype='float32',
main_program=main_program,
startup_program=startup_program)
out = layers.elementwise_add(
out = fluid.layers.elementwise_add(
x=image1,
y=image2,
act='relu',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册