diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b8eef7c4a68177fec6715e8761fce4dede023530..56cc02c5c5c12eb289faeea9ddd8fd4da116d9a6 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -238,6 +238,7 @@ paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_c paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6')) paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a')) paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919')) +paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '6e19128b46936edf9f3fad77860a1da8')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6')) diff --git a/paddle/fluid/operators/unfold_op.cc b/paddle/fluid/operators/unfold_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d21340b478b590259b04ce66a3db129fdb50c7e7 --- /dev/null +++ b/paddle/fluid/operators/unfold_op.cc @@ -0,0 +1,184 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/unfold_op.h" + +namespace paddle { +namespace operators { + +class UnfoldOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "Tensor, " + "the input of unfold op. " + "The format of X is [N, C_in, H, W], " + "where N is the batch size, C_in is the input channels, " + "H is the height and W is the width"); + AddOutput( + "Y", + "Tensor, " + "the output of unfold op. " + "The format of Y is [N, C_in*filter_height*filter_width, " + "output_height*output_width], where N is the batch size, " + "C_in is the input channels of X, filter_height and filter_width is " + "height and width of the filtering kernel, output_height and " + "output_width " + "is the calculated height and width of output feature map."); + AddAttr>( + "kernel_sizes", + "vector, the kernel sizes of the convolution operator."); + AddAttr>( + "strides", "vector, the strides of the convolution operator."); + AddAttr>( + "paddings", + "vector, the paddings applied to pad the feature map."); + AddAttr>( + "dilations", "vector, the dilations of the convolution operator."); + AddComment(R"DOC( +**Unfold Operator** + +This Operator is used to extract sliding local blocks from a batched input tensor, also known +as im2col when operated on batched 2D image tensor. For each block under the convolution filter, +all element will be rearranged as a column. While the convolution filter silding over the input +feature map, a series of such columns will be formed. + )DOC"); + } +}; + +class UnfoldOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnfoldOp should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), + "Output(Y) of UnfoldOp should not be null"); + auto in_dims = ctx->GetInputDim("X"); + std::vector kernel_sizes = + ctx->Attrs().Get>("kernel_sizes"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector dilations = + ctx->Attrs().Get>("dilations"); + + // Only [N, C, H, W] input supported now + PADDLE_ENFORCE( + in_dims.size() == 4, + "Input shold be 4-D tensor of format [N, C, H, W], but get %u", + in_dims.size()); + PADDLE_ENFORCE( + in_dims.size() - kernel_sizes.size() == 2U, + "The dims of X should be larger than that of kernel_sizes " + "by a number of 2, due to the batch size and input channel dim. " + "But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2", + in_dims.size(), kernel_sizes.size()); + PADDLE_ENFORCE_EQ( + strides.size(), kernel_sizes.size(), + "The dims of strides shold be the same with that of kernel_sizes. " + "But recieved dims(strides: %u) != dims(kernel_sizes: %u).", + strides.size(), kernel_sizes.size()); + PADDLE_ENFORCE_EQ( + paddings.size(), 2 * strides.size(), + "The dims of paddings should be 2 times of that of strides. " + "But recieved dims(paddings: %u) != 2*dims(strides: %u).", + paddings.size(), strides.size()); + PADDLE_ENFORCE_EQ( + strides.size(), dilations.size(), + "The dims of strides shold be the same with that of dilations. " + "But recieved dims(strides: %u) != dims(dilations: %u).", + strides.size(), dilations.size()); + + std::vector out_dims; + out_dims.push_back(in_dims[0]); + + int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1]; + out_dims.push_back(output_channels); + + int output_height = + CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0], + paddings[2], strides[0]); + int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], + paddings[1], paddings[3], strides[1]); + int output_col_length = output_height * output_width; + out_dims.push_back(output_col_length); + + ctx->SetOutputDim("Y", framework::make_ddim(out_dims)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class UnfoldGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "The gradient of Y should not be null"); + PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "The gradient of X should not be null"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Y"))->type(), + ctx.device_context()); + } +}; + +class UnfoldGradDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("unfold_grad"); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetInput("X", Input("X")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(UnfoldGradOpNoNeedBufferVarsInference, + "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unfold, ops::UnfoldOp, ops::UnfoldOpMaker, + ops::UnfoldGradDescMaker); +REGISTER_OPERATOR(unfold_grad, ops::UnfoldGradOp, + ops::UnfoldGradOpNoNeedBufferVarsInference); + +REGISTER_OP_CPU_KERNEL( + unfold, ops::UnfoldOpKernel, + ops::UnfoldOpKernel); +REGISTER_OP_CPU_KERNEL( + unfold_grad, + ops::UnfoldGradOpKernel, + ops::UnfoldGradOpKernel); diff --git a/paddle/fluid/operators/unfold_op.cu b/paddle/fluid/operators/unfold_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..46584506d431564cfc7af11072eee6c544f03564 --- /dev/null +++ b/paddle/fluid/operators/unfold_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unfold_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + unfold, ops::UnfoldOpKernel, + ops::UnfoldOpKernel); + +REGISTER_OP_CUDA_KERNEL( + unfold_grad, + ops::UnfoldGradOpKernel, + ops::UnfoldGradOpKernel); diff --git a/paddle/fluid/operators/unfold_op.h b/paddle/fluid/operators/unfold_op.h new file mode 100644 index 0000000000000000000000000000000000000000..97e8143bc052b346a85b50ab26bb8563e48f30d9 --- /dev/null +++ b/paddle/fluid/operators/unfold_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/im2col.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +inline int CalcOutputSize(int input_size, int filter_size, int dilation, + int padding1, int padding2, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + padding1 + padding2 - dkernel) / stride + 1; + PADDLE_ENFORCE(output_size > 0, + "Due to the settings of padding(%d, %d), filter_size(%d), " + "dilation(%d) and " + "stride(%d), the output size is less than 0, please check " + "again. Input_size:%d", + padding1, padding2, filter_size, dilation, stride, input_size); + + return output_size; +} + +template +class UnfoldOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* input = ctx.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + Tensor* output = ctx.Output("Y"); + output->mutable_data(ctx.GetPlace()); + + std::vector kernel_sizes = ctx.Attr>("kernel_sizes"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + math::Im2ColFunctor im2col; + auto& dev_ctx = ctx.template device_context(); + + auto input_dims = input->dims(); + + int output_height = + CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0], + paddings[0], paddings[2], strides[0]); + int output_width = + CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1], + paddings[1], paddings[3], strides[1]); + + framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]}); + framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0], + kernel_sizes[1], output_height, + output_width}); + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + im2col(dev_ctx, in_batch, dilations, strides, paddings, &out_batch); + } + } +}; + +template +class UnfoldGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = ctx.Input(framework::GradVarName("Y")); + Tensor* input_grad = ctx.Output(framework::GradVarName("X")); + input_grad->mutable_data(ctx.GetPlace()); + + if ((!output_grad) || (!input_grad)) return; + + std::vector kernel_sizes = ctx.Attr>("kernel_sizes"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + const int batch_size = static_cast(input_grad->dims()[0]); + + auto input_dims = input_grad->dims(); + + int output_height = + CalcOutputSize(input_dims[2], kernel_sizes[0], dilations[0], + paddings[0], paddings[2], strides[0]); + int output_width = + CalcOutputSize(input_dims[3], kernel_sizes[1], dilations[1], + paddings[1], paddings[3], strides[1]); + + framework::DDim input_shape({input_dims[1], input_dims[2], input_dims[3]}); + framework::DDim output_matrix_shape({input_dims[1], kernel_sizes[0], + kernel_sizes[1], output_height, + output_width}); + + math::Col2ImFunctor col2im; + auto& dev_ctx = ctx.template device_context(); + + math::SetConstant set_zero; + set_zero(dev_ctx, input_grad, static_cast(0)); + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + col2im(dev_ctx, out_grad_batch, dilations, strides, paddings, + &in_grad_batch); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 44b6794f34e8ea540ada0103671d4a08ff374c3c..780043b469bb22eb90893256583991b99dc25711 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -203,6 +203,7 @@ __all__ = [ 'where', 'sign', 'deformable_conv', + 'unfold', ] kIgnoreIndex = -100 @@ -12057,3 +12058,113 @@ def deformable_conv(input, output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) return output + + +def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None): + """ + + This function returns a col buffer of sliding local blocks of input x, also known + as im2col for batched 2D image tensors. For each block under the convolution filter, + all element will be rearranged as a column. While the convolution filter silding over + the input feature map, a series of such columns will be formed. + + For each input :math:`X` with shape [N, C, H, W], the output shape [N, Cout, Lout] + can be calculated as following. + + .. math:: + + dkernel[0] &= dilations[0] \\times (kernel\_sizes[0] - 1) + 1 + + dkernel[1] &= dilations[1] \\times (kernel\_sizes[1] - 1) + 1 + + hout &= \\frac{H + paddings[0] + paddings[2] - dkernel[0]}{strides[0]} + 1 + + wout &= \\frac{W + paddings[1] + paddings[3] - dkernel[1]}{strides[1]} + 1 + + Cout &= C \\times kernel\_sizes[0] \\times kernel\_sizes[1] + + Lout &= hout \\times wout + + + Args: + x(Varaible): The input tensor of format [N, C, H, W]. + kernel_sizes(int|list): The size of convolution kernel, should be [k_h, k_w] + or an integer k treated as [k, k]. + strides(int|list): The strides, should be [stride_h, stride_w] + or an integer stride treated as [sride, stride]. + For default, strides will be [1, 1]. + paddings(int|list): The paddings of each dimension, should be + [padding_top, padding_left, padding_bottom, padding_right] + or [padding_h, padding_w] or an integer padding. + If [padding_h, padding_w] was given, it will expanded to + [padding_h, padding_w, padding_h, padding_w]. If an integer + padding was given, [padding, padding, padding, padding] will + be used. For default, paddings will be [0, 0, 0, 0] + dilations(int|list): the dilations of convolution kernel, shold be + [dilation_h, dilation_w], or an integer dialtion treated as + [dilation, dilation]. For default, it will be [1, 1]. + + + Returns: + Variable: The tensor variable corresponding to the sliding local blocks. The output shape is [N, Cout, Lout] as decribled above. Cout is the total number of values within each block, and Lout is the total number of such blocks. + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + x = fluid.layers.data(name = 'data', shape = [3, 224, 224], dtype = 'float32') + y = fluid.layers.unfold(x, [3, 3], 1, 1, 1) + """ + + helper = LayerHelper("unfold", **locals()) + + assert len(x.shape) == 4, \ + "input should be the format of [N, C, H, W]" + + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes, kernel_sizes] + else: + assert isinstance(kernel_sizes, list) and (len(kernel_sizes) == 2), \ + "kernel_sizes should either be an integer or a list of two integers" + + if isinstance(strides, int): + strides = [strides, strides] + else: + assert isinstance(strides, list) and (len(strides) == 2), \ + "strides should either be an integer or a list of two integers" + + if isinstance(dilations, int): + dilations = [dilations, dilations] + else: + assert isinstance(dilations, list) and (len(dilations) == 2), \ + "dilations should either be an integer or a list of two integers" + + if isinstance(paddings, int): + paddings = [paddings] * 4 + elif isinstance(paddings, list): + if len(paddings) == 2: + paddings = paddings * 2 + elif len(paddings) == 4: + pass + else: + raise ValueError( + "paddings should either be an integer or a list of 2 or 4 integers" + ) + else: + raise ValueError( + "Unexpected type of paddings, it should be either an integer or a list" + "of 2 or 4 integers") + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="unfold", + inputs={"X": x}, + outputs={"Y": out}, + attrs={ + "kernel_sizes": kernel_sizes, + "strides": strides, + "paddings": paddings, + "dilations": dilations + }) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 2204ea21c032e065f8739f87815cee82aa75f195..4359e39a8556cb03651c2c569dbf3457c9733a9a 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1989,6 +1989,12 @@ class TestBook(LayerTest): padding=1) return (out) + def test_unfold(self): + with self.static_graph(): + x = layers.data(name='x', shape=[3, 20, 20], dtype='float32') + out = layers.unfold(x, [3, 3], 1, 1, 1) + return (out) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unfold_op.py b/python/paddle/fluid/tests/unittests/test_unfold_op.py new file mode 100644 index 0000000000000000000000000000000000000000..379982b60682c166ddb737b6a009d1ea758c0729 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unfold_op.py @@ -0,0 +1,102 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import math +import numpy as np +import unittest +from op_test import OpTest + + +class TestUnfoldOp(OpTest): + """ + This is for test on unfold Op + """ + + def init_data(self): + self.batch_size = 3 + self.input_channels = 3 + self.input_height = 20 + self.input_width = 20 + self.kernel_sizes = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1, 1, 1] + self.dilations = [1, 1] + input_shape = [ + self.batch_size, self.input_channels, self.input_height, + self.input_width + ] + self.x = np.random.rand(*input_shape).astype(np.float32) + + def calc_unfold(self): + output_shape = [0] * 3 + output_shape[0] = self.batch_size + output_shape[1] = self.input_channels * self.kernel_sizes[ + 0] * self.kernel_sizes[1] + dkernel_h = self.dilations[0] * (self.kernel_sizes[0] - 1) + 1 + dkernel_w = self.dilations[1] * (self.kernel_sizes[1] - 1) + 1 + out_height = int((self.input_height + self.paddings[0] + + self.paddings[2] - dkernel_h) / self.strides[0]) + 1 + out_width = int((self.input_width + self.paddings[1] + self.paddings[3] + - dkernel_w) / self.strides[1]) + 1 + output_shape[2] = out_height * out_width + output = np.zeros(output_shape).astype(np.float32) + ############ calculate output ############## + for i in range(output_shape[0]): + for j in range(output_shape[1]): + for k in range(output_shape[2]): + h_out = int(k / out_width) + w_out = k % out_width + w_offset = j % self.kernel_sizes[1] + h_offset = int(j / + self.kernel_sizes[1]) % self.kernel_sizes[0] + c_in = int(j / + (self.kernel_sizes[0] * self.kernel_sizes[1])) + h_in = h_offset * self.dilations[0] + h_out * self.strides[ + 0] - self.paddings[0] + w_in = w_offset * self.dilations[1] + w_out * self.strides[ + 1] - self.paddings[1] + if (h_in>=0 and h_in=0 and w_in