From e82f1008a82232936529ec4bba70a59880915912 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 17 Jan 2018 00:42:20 +0800 Subject: [PATCH] Finish block expand op 1. Add lod to output 2. Fix im2col arguments list 3. Refine code and doc 4. Fix output shape --- paddle/operators/block_expand_op.cc | 119 +++++++---- paddle/operators/block_expand_op.h | 140 ++++++------ .../v2/fluid/tests/test_block_expand_op.py | 202 ++++++++---------- 3 files changed, 239 insertions(+), 222 deletions(-) diff --git a/paddle/operators/block_expand_op.cc b/paddle/operators/block_expand_op.cc index f25cc4f9d..317a43bb7 100644 --- a/paddle/operators/block_expand_op.cc +++ b/paddle/operators/block_expand_op.cc @@ -32,37 +32,27 @@ class BlockExpandOp : public framework::OperatorWithKernel { auto in_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input(X) format must be 4D tensor, eg., NCHW."); - PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1."); - int block_height = ctx->Attrs().Get("blockHeight"); - int block_width = ctx->Attrs().Get("blockWidth"); - int stride_height = ctx->Attrs().Get("strideHeight"); - int stride_width = ctx->Attrs().Get("strideWidth"); - int padding_height = ctx->Attrs().Get("paddingHeight"); - int padding_width = ctx->Attrs().Get("paddingWidth"); + int block_height = ctx->Attrs().Get("block_height"); + int block_width = ctx->Attrs().Get("block_width"); + int stride_height = ctx->Attrs().Get("stride_height"); + int stride_width = ctx->Attrs().Get("stride_width"); + int padding_height = ctx->Attrs().Get("padding_height"); + int padding_width = ctx->Attrs().Get("padding_width"); - int N = in_dim[0]; - int C = in_dim[1]; + int batch_size = in_dim[0]; + int img_channels = in_dim[1]; int img_height = in_dim[2]; int img_width = in_dim[3]; - int output_height = 0; - int output_width = 0; + int output_height = get_output_size(img_height, block_height, stride_height, + padding_height); + int output_width = + get_output_size(img_width, block_width, stride_width, padding_width); - get_blockexpand_output_shape(img_height, img_width, block_height, - block_width, stride_height, stride_width, - padding_height, padding_width, output_height, - output_width); - - // The result of im2col is [output_height, output_width, - // inputChannels, filterHeight, filterWidth], and it is easy to - // reshape into [seqLength, stepSize], where seqLength is equal - // output_height * output_width, stepSize is equal - // input_channels * blockHeight * blockWidth - ctx->SetOutputDim( - "Out", {N, output_height, output_width, C, block_height, block_width}); - - // ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("Out", {batch_size * output_height * output_width, + img_channels * block_height * block_width}); + // TODO(wanghaoshuang): cal lod in complie time } }; @@ -79,28 +69,69 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { W: width )DOC"); AddOutput("Out", "(LodTensor)The output data of block_expand op,"); - AddAttr("blockHeight", "(int)height of block."); - AddAttr("blockWidth", "(int)width of block."); - AddAttr("strideHeight", "(int)height of stride."); - AddAttr("strideWidth", "(int)width of stride."); - AddAttr("paddingHeight", "(int)height of padding."); - AddAttr("paddingWidth", "(int)width of padding."); + AddAttr("block_height", "(int)height of block."); + AddAttr("block_width", "(int)width of block."); + AddAttr("stride_height", "(int)height of stride."); + AddAttr("stride_width", "(int)width of stride."); + AddAttr("padding_height", "(int)height of padding."); + AddAttr("padding_width", "(int)width of padding."); AddComment(R"DOC( Expand feature map to minibatch matrix. - matirx height is: output_height * output_width -- matrix width is: blockHeight * blockWidth * channels - -output_height = - 1 + (2 * paddingHeight + img_height - blockHeight + strideHeight - 1) / - strideHeight; -output_width = - 1 + (2 * paddingWidth + img_width - blockWidth + strideWidth - 1) / - strideWidth; - -The expand method is the same with ExpandConvLayer, but saved the transposed -value. After expanding, The number of time steps are output_height * output_width -and the dimension of each time step is blockHeight * blockWidth * channels. -This layer can be used after convolution neural network, and before recurrent neural network. +- matrix width is: block_height * block_width * channels + +output_height = + 1 + (2 * padding_height + img_height - block_height + stride_height - 1) / + stride_height; +output_width = + 1 + (2 * padding_width + img_width - block_width + stride_width - 1) / + stride_width; + +After expanding, The number of time steps are output_height * output_width +and the dimension of each time step is block_height * block_width * channels. +This op can be used after convolution neural network, and before recurrent neural network. + +Given: + +x = [[[[ 6. 2. 1.] + [ 8. 3. 5.] + [ 0. 2. 6.]] + + [[ 2. 4. 4.] + [ 6. 3. 0.] + [ 6. 4. 7.]]] + + [[[ 6. 7. 1.] + [ 5. 7. 9.] + [ 2. 4. 8.]] + + [[ 1. 2. 1.] + [ 1. 3. 5.] + [ 9. 0. 8.]]]] +x.dims = {2, 2, 3, 3} + +And: + +block_height = 2 +block_width = 2 +stride_height = 1 +stride_width = 1 +padding_height = 0 +padding_width = 0 + +Then: + +output.data = [[ 6. 2. 8. 3. 2. 4. 6. 3.] + [ 2. 1. 3. 5. 4. 4. 3. 0.] + [ 8. 3. 0. 2. 6. 3. 6. 4.] + [ 3. 5. 2. 6. 3. 0. 4. 7.] + [ 6. 7. 5. 7. 1. 2. 1. 3.] + [ 7. 1. 7. 9. 2. 1. 3. 5.] + [ 5. 7. 2. 4. 1. 3. 9. 0.] + [ 7. 9. 4. 8. 3. 5. 0. 8.]] +output.dims = {8, 9} +output.lod = [[0, 4, 8]] + )DOC"); } }; diff --git a/paddle/operators/block_expand_op.h b/paddle/operators/block_expand_op.h index aa0db2705..022dc3a12 100644 --- a/paddle/operators/block_expand_op.h +++ b/paddle/operators/block_expand_op.h @@ -23,20 +23,9 @@ namespace paddle { namespace operators { -inline void get_blockexpand_output_shape(int img_height, int img_width, - int block_height, int block_width, - int stride_height, int stride_width, - int padding_height, int padding_width, - int& outputHeight, int& outputWidth) { - outputHeight = - 1 + - (img_height + 2 * padding_height - block_height + stride_height - 1) / - stride_height; - - outputWidth = - 1 + - (img_width + 2 * padding_width - block_width + stride_width - 1) / - stride_width; +inline int get_output_size(int img_size, int block_size, int stride, + int padding) { + return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride); } template @@ -45,40 +34,54 @@ class BlockExpandKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { using namespace framework; const Tensor* in = ctx.Input("X"); - Tensor* out = ctx.Output("Out"); + LoDTensor* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto in_dim = in->dims(); - int N = in_dim[0]; - int C = in_dim[1]; + int batch_size = in_dim[0]; + int img_channels = in_dim[1]; int img_height = in_dim[2]; int img_width = in_dim[3]; - - int block_height = ctx.Attr("blockHeight"); - int block_width = ctx.Attr("blockWidth"); - int stride_height = ctx.Attr("strideHeight"); - int stride_width = ctx.Attr("strideWidth"); - int padding_height = ctx.Attr("paddingHeight"); - int padding_width = ctx.Attr("paddingWidth"); - - int outputHeight = 0; - int outputWidth = 0; - - get_blockexpand_output_shape( - img_height, img_width, block_height, block_width, stride_height, - stride_width, padding_height, padding_width, outputHeight, outputWidth); - - std::vector stride({stride_height, stride_width}); - std::vector padding({padding_height, padding_width}); - - for (int i = 0; i < N; i++) { - Tensor src = in->Slice(i, i + 1).Resize({C, img_height, img_width}); - Tensor dst = out->Slice(i, i + 1).Resize( - {outputHeight, outputWidth, C, block_height, block_width}); + int block_height = ctx.Attr("block_height"); + int block_width = ctx.Attr("block_width"); + int stride_height = ctx.Attr("stride_height"); + int stride_width = ctx.Attr("stride_width"); + int padding_height = ctx.Attr("padding_height"); + int padding_width = ctx.Attr("padding_width"); + + int output_height = get_output_size(img_height, block_height, stride_height, + padding_height); + int output_width = + get_output_size(img_width, block_width, stride_width, padding_width); + + const std::vector dilations({1, 1}); + const std::vector strides( + {stride_height, stride_width, stride_height, stride_width}); + const std::vector paddings( + {padding_height, padding_width, padding_height, padding_width}); + + auto out_dims = out->dims(); + out->Resize({batch_size, out->numel() / batch_size}); + for (int i = 0; i < batch_size; i++) { + const Tensor src = + in->Slice(i, i + 1).Resize({img_channels, img_height, img_width}); + Tensor dst = out->Slice(i, i + 1).Resize({output_height, output_width, + img_channels, block_height, + block_width}); math::Im2ColFunctor f; - f(ctx.device_context(), src, stride, padding, &dst); + f(ctx.device_context(), src, dilations, strides, paddings, &dst); } + out->Resize(out_dims); + + // set lod information + // TODO(wanghaoshuang): Move this to InferShape + framework::LoD lod(1); + for (int i = 0, offset = 0; i < batch_size + 1; ++i) { + lod[0].push_back(offset); + offset += output_height * output_width; + } + out->set_lod(lod); } }; @@ -88,7 +91,8 @@ class BlockExpandGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { using namespace framework; auto* in = ctx.Input("X"); - auto* d_out = ctx.Input(framework::GradVarName("Out")); + Tensor* d_out = + const_cast(ctx.Input(framework::GradVarName("Out"))); auto* d_x = ctx.Output(GradVarName("X")); d_x->mutable_data(ctx.GetPlace()); @@ -96,36 +100,40 @@ class BlockExpandGradKernel : public framework::OpKernel { x_v.device(ctx.GetEigenDevice()) = x_v.constant(0.0); auto in_dim = in->dims(); - int N = in_dim[0]; - int C = in_dim[1]; + int batch_size = in_dim[0]; + int img_channels = in_dim[1]; int img_height = in_dim[2]; int img_width = in_dim[3]; - int block_height = ctx.Attr("blockHeight"); - int block_width = ctx.Attr("blockWidth"); - int stride_height = ctx.Attr("strideHeight"); - int stride_width = ctx.Attr("strideWidth"); - int padding_height = ctx.Attr("paddingHeight"); - int padding_width = ctx.Attr("paddingWidth"); - - int outputHeight = 0; - int outputWidth = 0; - - get_blockexpand_output_shape( - img_height, img_width, block_height, block_width, stride_height, - stride_width, padding_height, padding_width, outputHeight, outputWidth); - - std::vector stride({stride_height, stride_width}); - std::vector padding({padding_height, padding_width}); - // std::vector stride({stride_height, stride_width}); - - for (int i = 0; i < N; i++) { - Tensor dst = d_x->Slice(i, i + 1).Resize({C, img_height, img_width}); - Tensor src = d_out->Slice(i, i + 1).Resize( - {outputHeight, outputWidth, C, block_height, block_width}); + int block_height = ctx.Attr("block_height"); + int block_width = ctx.Attr("block_width"); + int stride_height = ctx.Attr("stride_height"); + int stride_width = ctx.Attr("stride_width"); + int padding_height = ctx.Attr("padding_height"); + int padding_width = ctx.Attr("padding_width"); + int output_height = get_output_size(img_height, block_height, stride_height, + padding_height); + int output_width = + get_output_size(img_width, block_width, stride_width, padding_width); + + const std::vector dilations({1, 1}); + const std::vector strides( + {stride_height, stride_width, stride_height, stride_width}); + const std::vector paddings( + {padding_height, padding_width, padding_height, padding_width}); + + auto d_out_dims = d_out->dims(); + d_out->Resize({batch_size, d_out->numel() / batch_size}); + for (int i = 0; i < batch_size; i++) { + Tensor dst = + d_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width}); + const Tensor src = d_out->Slice(i, i + 1).Resize( + {output_height, output_width, img_channels, block_height, + block_width}); math::Col2ImFunctor f; - f(ctx.device_context(), dst, stride, padding, &src); + f(ctx.device_context(), src, dilations, strides, paddings, &dst); } + d_out->Resize(d_out_dims); } }; diff --git a/python/paddle/v2/fluid/tests/test_block_expand_op.py b/python/paddle/v2/fluid/tests/test_block_expand_op.py index b31ed53f4..424bc7dc6 100644 --- a/python/paddle/v2/fluid/tests/test_block_expand_op.py +++ b/python/paddle/v2/fluid/tests/test_block_expand_op.py @@ -4,20 +4,20 @@ from op_test import OpTest def get_output_shape(attrs, x): - img_height = x.shape[1] - img_width = x.shape[2] + img_height = x.shape[2] + img_width = x.shape[3] - padding_height = attrs['paddingHeight'] - padding_width = attrs['paddingWidth'] - block_height = attrs['blockHeight'] - block_width = attrs['blockWidth'] - stride_height = attrs['strideHeight'] - stride_width = attrs['strideWidth'] + padding_height = attrs['padding_height'] + padding_width = attrs['padding_width'] + block_height = attrs['block_height'] + block_width = attrs['block_width'] + stride_height = attrs['stride_height'] + stride_width = attrs['stride_width'] output_height = \ 1 + \ (img_height + 2 * padding_height - block_height + stride_height - 1) / \ - strideHeight + stride_height output_width = \ 1 + \ @@ -42,10 +42,10 @@ def im2col(attrs, im, col): filter_height = col.shape[3] filter_width = col.shape[4] - stride_height = attrs['strideHeight'] - stride_width = attrs['strideWidth'] - padding_height = attrs['paddingHeight'] - padding_width = attrs['paddingWidth'] + stride_height = attrs['stride_height'] + stride_width = attrs['stride_width'] + padding_height = attrs['padding_height'] + padding_width = attrs['padding_width'] for col_row_idx in range(0, output_height): for col_col_idx in range(0, output_width): @@ -73,83 +73,51 @@ def im2col(attrs, im, col): im_row_offset][im_col_offset] -def col2img(attrs, col, img): - """ - img: {CHW} - col: - {output_height, outputWidth, inputChannels, filterHeight, filterWidth} - """ - input_channels = im.shape[0] - input_height = im.shape[1] - input_width = im.shape[2] - - output_height = col.shape[0] - output_width = col.shape[1] - filter_height = col.shape[3] - filter_width = col.shape[4] +def block_expand(inputs, attrs): + output_height, output_width = get_output_shape(attrs, inputs) + img_channels = inputs.shape[1] + batch_size = inputs.shape[0] + out = np.zeros([ + batch_size, output_height, output_width, img_channels, + attrs['block_height'], attrs['block_width'] + ]).astype("float32") - stride_height = attrs['strideHeight'] - stride_width = attrs['strideWidth'] - padding_height = attrs['paddingHeight'] - padding_width = attrs['paddingWidth'] + for i in range(len(inputs)): + im2col(attrs, inputs[i], out[i]) - for col_row_idx in range(0, output_height): - for col_col_idx in range(0, output_width): - for channel in range(0, input_channels): - for filter_row_idx in range(0, filter_height): - for filter_col_idx in range(0, filter_width): - im_row_offset = \ - col_row_idx * stride_height + filter_row_idx - padding_height - im_col_offset = \ - col_col_idx * stride_width + filter_col_idx - padding_width - if (im_row_offset >= 0 and - im_row_offset < input_height and - im_col_offset >= 0 and - im_col_offset < input_width): - im[channel][im_row_offset][im_col_offset] = \ - col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx] - - -def get_input_data(C, H, W): - x = np.random.uniform(0.1, 1, [C, H, W]).astype("float32") - for c in range(0, C): - for h in range(0, H): - for w in range(0, W): - #x[c][h][w] = c * H * W + h *W + w - x[c][h][w] = 0.2 + 0.01 * (c * H * W + h * W + w) - return x + out = out.reshape([ + batch_size * output_height * output_width, + img_channels * attrs['block_height'] * attrs['block_width'] + ]) + return out class TestBlockExpandOp(OpTest): - def setUp(self): - C = 3 - H = 4 - W = 4 - x = get_input_data(C, H, W) - - attrs = { - 'blockHeight': 2, - 'blockWidth': 2, - 'strideHeight': 1, - 'strideWidth': 1, - 'paddingHeight': 1, - 'paddingWidth': 1, + def config(self): + self.batch_size = 1 + self.img_channels = 3 + self.img_height = 4 + self.img_width = 4 + self.attrs = { + 'block_height': 2, + 'block_width': 2, + 'stride_height': 1, + 'stride_width': 1, + 'padding_height': 1, + 'padding_width': 1, } - output_height, output_width = get_output_shape(attrs, x) - out = np.random.uniform(0.1, 1,\ - [output_height, output_width, x.shape[0], \ - attrs['blockHeight'], attrs['blockWidth']]).astype("float32") - + def setUp(self): + self.config() self.op_type = "block_expand" - self.inputs = {'X': x.reshape(1, C, H, W)} - self.attrs = attrs + #x = np.random.uniform(0.1, 1, + x = np.random.randint(0, 10, [ + self.batch_size, self.img_channels, self.img_height, self.img_width + ]).astype("float32") - im2col(attrs, x, out) - self.outputs = { - 'Out':out.reshape(1, output_height, output_width, x.shape[0], \ - attrs['blockHeight'], attrs['blockWidth']) - } + out = block_expand(x, self.attrs) + self.inputs = {'X': x} + self.outputs = {'Out': out} def test_check_output(self): self.check_output() @@ -158,42 +126,52 @@ class TestBlockExpandOp(OpTest): self.check_grad(['X'], 'Out') -class TestBlockExpandOp2(OpTest): - def setUp(self): - C = 3 - H = 4 - W = 5 - x = get_input_data(C, H, W) - - attrs = { - 'blockHeight': 2, - 'blockWidth': 1, - 'strideHeight': 2, - 'strideWidth': 1, - 'paddingHeight': 2, - 'paddingWidth': 1, +class TestBlockExpandOpCase2(TestBlockExpandOp): + def config(self): + self.batch_size = 2 + self.img_channels = 3 + self.img_height = 4 + self.img_width = 5 + self.attrs = { + 'block_height': 2, + 'block_width': 1, + 'stride_height': 2, + 'stride_width': 1, + 'padding_height': 2, + 'padding_width': 1, } - output_height, output_width = get_output_shape(attrs, x) - out = np.random.uniform(0.1, 1,\ - [output_height, output_width, x.shape[0], \ - attrs['blockHeight'], attrs['blockWidth']]).astype("float32") - - self.op_type = "block_expand" - self.inputs = {'X': x.reshape(1, C, H, W)} - self.attrs = attrs - im2col(attrs, x, out) - self.outputs = { - 'Out':out.reshape(1, output_height, output_width, x.shape[0], \ - attrs['blockHeight'], attrs['blockWidth']) - } +class TestBlockExpandOpCase3(TestBlockExpandOp): + def config(self): + self.batch_size = 3 + self.img_channels = 1 + self.img_height = 4 + self.img_width = 5 + self.attrs = { + 'block_height': 2, + 'block_width': 1, + 'stride_height': 2, + 'stride_width': 1, + 'padding_height': 2, + 'padding_width': 0, + } - def test_check_output(self): - self.check_output() - def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') +class TestBlockExpandOpCase4(TestBlockExpandOp): + def config(self): + self.batch_size = 2 + self.img_channels = 2 + self.img_height = 3 + self.img_width = 3 + self.attrs = { + 'block_height': 2, + 'block_width': 2, + 'stride_height': 1, + 'stride_width': 1, + 'padding_height': 0, + 'padding_width': 0, + } if __name__ == '__main__': -- GitLab