diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d8169fd2deb8c65bde6596fcbd9d8c1e61634263..2b102bd6804f29459157f67388f1eccc9e0d05eb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -267,6 +267,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e')) paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6')) paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35')) +paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c')) paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899')) diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index 4ec0a35b2900a17f55428bb0e2cea3c9aa69c620..6ac5f994222ec56775b38f8d077579b825419f22 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -41,3 +41,4 @@ tensor_array_to_tensor transpose unpool unsqueeze +var_conv_2d diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 98ff3ea14659634535fcdbfe4f33c663e6dfbc2f..3770d30831c550bade7d2b1b294f2704b6b8fd22 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -48,8 +48,13 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() +SET(OP_ONLY_MKL "") +if (NOT WITH_MKL) + SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op) +endif() + register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op - sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) + sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) if (WITH_GPU) # warpctc_op needs cudnn 7 above diff --git a/paddle/fluid/operators/var_conv_2d_op.cc 
b/paddle/fluid/operators/var_conv_2d_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..232075203a0705ba5c68c80bae7cbf4613cbb970
--- /dev/null
+++ b/paddle/fluid/operators/var_conv_2d_op.cc
@@ -0,0 +1,431 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/var_conv_2d_op.h"
+#include <vector>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/dynload/mklml.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
+
+void VarConv2dOpMaker::Make() {
+  AddInput("X",
+           "X (LoDTensor, default LoDTensor<float>) Input variable which "
+           "should contain lod information.");
+  AddInput("ROW", "(LoDTensor) the row variable provides lod information");
+  AddInput("COLUMN",
+           "(LoDTensor) the column variable provides lod information");
+  AddInput("W", "W (Tensor), the filter.");
+  AddAttr<int>("InputChannel", "the input filter num").SetDefault(1);
+  AddAttr<int>("OutputChannel", "the output filter num").SetDefault(1);
+  AddAttr<int>("StrideH", "the height of Stride").SetDefault(1);
+  AddAttr<int>("StrideW", "the width of Stride").SetDefault(1);
+  AddAttr<int>("KernelH", "the height of Kernel").SetDefault(1);
+  AddAttr<int>("KernelW", "the width of Kernel").SetDefault(1);
+
+  AddOutput("Out", "(LoDTensor, default LoDTensor<float>) Output variable");
+  AddOutput("Col",
+            "(LoDTensor, default LoDTensor<float>) the intermediate result "
+            "variable");
+
+  AddComment(R"DOC(
+    Var Size Conv Operator
+
+    This operator calculates Out = \sigma \left ( W * X + b \right );
+    only 2-D X is supported.
+
+    NOTE: only the 'float32' data type is supported now.
+
+  )DOC");
+}
+
+void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "X(Input) of VarConv2dOP should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("W"),
+                 "W(Input) of VarConv2dOP should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("ROW"),
+                 "Input(ROW) of VarConv2dOP should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("COLUMN"),
+                 "Input(COLUMN) of VarConv2dOP should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "Out(Output) of VarConv2dOP should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Col"),
+                 "Col(Output) of VarConv2dOP should not be null.");
+
+  auto x_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of X(Input) should be 2.");
+
+  auto w_dims = ctx->GetInputDim("W");
+
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be a 2-D tensor.");
+  int output_channel = ctx->Attrs().Get<int>("OutputChannel");
+  int input_channel = ctx->Attrs().Get<int>("InputChannel");
+  int kernel_h = ctx->Attrs().Get<int>("KernelH");
+  int kernel_w = ctx->Attrs().Get<int>("KernelW");
+  PADDLE_ENFORCE_EQ(w_dims[0], output_channel,
+                    "W dim[0] should be equal to OutputChannel");
+  PADDLE_ENFORCE_EQ(
+      w_dims[1], input_channel * kernel_h * kernel_w,
+      "W dim[1] should be equal to InputChannel * KernelH * KernelW");
+
+  if (ctx->IsRuntime()) {
+    framework::Variable* x_var =
+        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
+    const auto& x_lod = x_var->Get<LoDTensor>().lod();
+    PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
+
+    PADDLE_ENFORCE_GE(x_lod.size(), 1, "The Input(X)'s lod info is corrupted.");
+    PADDLE_ENFORCE_EQ(
+        x_dims[0], static_cast<int64_t>(x_lod[0].back()),
+        "The Input(X)'s lod info mismatches the actual tensor shape.");
+
+    framework::Variable* row_var =
+        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("ROW")[0]);
+    const auto& row_lod = row_var->Get<LoDTensor>().lod();
+    PADDLE_ENFORCE(!row_lod.empty(), "The Input(ROW) must hold lod info.");
+
+    framework::Variable* col_var =
+        boost::get<framework::Variable*>(ctx->GetInputVarPtrs("COLUMN")[0]);
+    const auto& col_lod = col_var->Get<LoDTensor>().lod();
+    PADDLE_ENFORCE(!col_lod.empty(), "The Input(COLUMN) must hold lod info.");
+  } else {
+    std::vector<int64_t> out_dims_vec{-1};
+    out_dims_vec.push_back(1);
+    std::vector<int64_t> col_dims_vec{-1};
+    col_dims_vec.push_back(1);
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
+    ctx->SetOutputDim("Col", framework::make_ddim(col_dims_vec));
+  }
+}
+
+template <typename DeviceContext, typename T>
+class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
+ public:
+  void Im2Col(const framework::ExecutionContext& ctx, const LoDTensor& input,
+              LoDTensor* col) const {
+    int input_channel = ctx.Attr<int>("InputChannel");
+    auto* in_row = ctx.Input<LoDTensor>("ROW");
+    auto* in_col = ctx.Input<LoDTensor>("COLUMN");
+    int kernel_h = ctx.Attr<int>("KernelH");
+    int kernel_w = ctx.Attr<int>("KernelW");
+    int stride_h = ctx.Attr<int>("StrideH");
+    int stride_w = ctx.Attr<int>("StrideW");
+
+    int batch = input.lod()[0].size() - 1;
+    const auto& bottom_offset = input.lod()[0];
+    // 2-D lod info.
+ const auto& offset_x = in_col->lod()[0]; + const auto& offset_y = in_row->lod()[0]; + + // top offset is the whole size of each data sample + std::vector top_offset; + int top_size = 0; + top_offset.push_back(top_size); + for (int b = 0; b < batch; ++b) { + int width = offset_x[b + 1] - offset_x[b]; + int height = offset_y[b + 1] - offset_y[b]; + int top_im_x = 0; + if (width == 0) { + top_im_x = 0; + } else { + top_im_x = (width - 1) / stride_w + 1; + } + int top_im_y = 0; + if (height == 0) { + top_im_y = 0; + } else { + top_im_y = (height - 1) / stride_h + 1; + } + int top_x = top_im_y * top_im_x; + int top_y = input_channel * kernel_h * kernel_w; + top_size += top_y * top_x; + top_offset.push_back(top_size); + } + framework::LoD col_lod; + col_lod.push_back(top_offset); + col->set_lod(col_lod); + std::vector col_dims_vec{top_size}; + col_dims_vec.push_back(1); + auto* top_data = col->mutable_data(framework::make_ddim(col_dims_vec), + ctx.GetPlace()); + auto* bottom_data = input.data(); + + int kernel_win_size = kernel_h * kernel_w; + int half_kernel_h = kernel_h / 2; + int half_kernel_w = kernel_w / 2; + for (int b = 0; b < batch; ++b) { + int t_offset = top_offset[b]; + int b_offset = bottom_offset[b]; + int width = offset_x[b + 1] - offset_x[b]; + int height = offset_y[b + 1] - offset_y[b]; + if (width == 0 || height == 0) { + continue; + } + int top_im_x = (width - 1) / stride_w + 1; + int top_im_y = (height - 1) / stride_h + 1; + int top_x = top_im_y * top_im_x; + for (int z = 0; z < input_channel; ++z) { + int row_offset = kernel_win_size * z; + int im_offset = z * width * height; + for (int y = 0; y < height; y += stride_h) { + for (int x = 0; x < width; x += stride_w) { + int col_offset = x / stride_w + y / stride_h * top_im_x; + for (int ky = 0; ky < kernel_h; ++ky) { + for (int kx = 0; kx < kernel_w; ++kx) { + int im_y = y + ky - half_kernel_h; + int im_x = x + kx - half_kernel_w; + if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) { + top_data[t_offset + + (row_offset + ky * kernel_w + kx) * top_x + + col_offset] = + bottom_data[b_offset + im_offset + im_y * width + im_x]; + } else { + top_data[t_offset + + (row_offset + ky * kernel_w + kx) * top_x + + col_offset] = 0; + } + } + } + } + } + } + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* bottom = ctx.Input("X"); + auto* in_row = ctx.Input("ROW"); + auto* in_col = ctx.Input("COLUMN"); + auto* w = ctx.Input("W"); + auto* top = ctx.Output("Out"); + auto* col = ctx.Output("Col"); + + int output_channel = ctx.Attr("OutputChannel"); + int input_channel = ctx.Attr("InputChannel"); + int kernel_h = ctx.Attr("KernelH"); + int kernel_w = ctx.Attr("KernelW"); + int stride_h = ctx.Attr("StrideH"); + int stride_w = ctx.Attr("StrideW"); + + Im2Col(ctx, *bottom, col); + int batch = bottom->lod()[0].size() - 1; + const auto& col_offset = col->lod()[0]; + const auto& offset_x = in_col->lod()[0]; + const auto& offset_y = in_row->lod()[0]; + std::vector top_offset; + int top_size = 0; + top_offset.push_back(top_size); + for (int b = 0; b < batch; ++b) { + int width = offset_x[b + 1] - offset_x[b]; + int height = offset_y[b + 1] - offset_y[b]; + int top_im_x = 0; + if (width == 0) { + top_im_x = 0; + } else { + top_im_x = (width - 1) / stride_w + 1; + } + int top_im_y = 0; + if (height == 0) { + top_im_y = 0; + } else { + top_im_y = (height - 1) / stride_h + 1; + } + int top_im_size = top_im_y * top_im_x; + top_size += output_channel * top_im_size; + top_offset.push_back(top_size); + } + + 
framework::LoD top_lod; + top_lod.push_back(top_offset); + + top->set_lod(top_lod); + std::vector top_dims_vec{top_size}; + top_dims_vec.push_back(1); + auto* top_data = top->mutable_data(framework::make_ddim(top_dims_vec), + ctx.GetPlace()); + + auto* w_data = w->data(); + auto* col_data = col->data(); + + auto blas = math::GetBlas(ctx); + for (int b = 0; b < batch; ++b) { + int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; + if (top_im_size == 0) { + continue; + } + + blas.GEMM(CblasNoTrans, CblasNoTrans, output_channel, top_im_size, + input_channel * kernel_h * kernel_w, 1.0, w_data, + col_data + col_offset[b], 0.0, top_data + top_offset[b]); + } + } +}; + +void VarConv2dOpGrad::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequencePadGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("W"), + "Input(W) of SequencePadGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequencePadGradOp should not be null."); + + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } + if (ctx->HasOutput(framework::GradVarName("W"))) { + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); + } +} + +template +class CPUVarConv2dOPGradKernel : public framework::OpKernel { + public: + void Im2ColGrad(const framework::ExecutionContext& ctx, T* top_diff) const { + auto* x = ctx.Input("X"); + auto* in_row = ctx.Input("ROW"); + auto* in_col = ctx.Input("COLUMN"); + auto* col = ctx.Input("Col"); + + int input_channel = ctx.Attr("InputChannel"); + int kernel_h = ctx.Attr("KernelH"); + int kernel_w = ctx.Attr("KernelW"); + int stride_h = ctx.Attr("StrideH"); + int stride_w = ctx.Attr("StrideW"); + + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + memset(dx_data, 0.0, x->dims()[0] * x->dims()[1] * sizeof(T)); + + const auto& bottom_offset = x->lod()[0]; + const auto& offset_x = in_col->lod()[0]; + const auto& offset_y = in_row->lod()[0]; + const auto& top_offset = col->lod()[0]; + int batch = x->lod()[0].size() - 1; + int kernel_win_size = kernel_h * kernel_w; + int half_kernel_h = kernel_h / 2; + int half_kernel_w = kernel_w / 2; + for (int b = 0; b < batch; ++b) { + int t_offset = top_offset[b]; + int b_offset = bottom_offset[b]; + int width = offset_x[b + 1] - offset_x[b]; + int height = offset_y[b + 1] - offset_y[b]; + if (width == 0 || height == 0) { + continue; + } + int top_im_x = (width - 1) / stride_w + 1; + int top_im_y = (height - 1) / stride_h + 1; + int top_x = top_im_y * top_im_x; + for (int z = 0; z < input_channel; ++z) { + int row_offset = kernel_win_size * z; + int im_offset = z * width * height; + for (int y = 0; y < height; y += stride_h) { + for (int x = 0; x < width; x += stride_w) { + int col_offset = x / stride_w + y / stride_h * top_im_x; + for (int ky = 0; ky < kernel_h; ++ky) { + for (int kx = 0; kx < kernel_w; ++kx) { + int im_y = y + ky - half_kernel_h; + int im_x = x + kx - half_kernel_w; + if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) { + dx_data[b_offset + im_offset + im_y * width + im_x] += + top_diff[t_offset + + (row_offset + ky * kernel_w + kx) * top_x + + col_offset]; + } + } + } + } + } + } + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* 
w = ctx.Input("W"); + auto* col = ctx.Input("Col"); + auto* out = ctx.Input("Out"); + + int output_channel = ctx.Attr("OutputChannel"); + int input_channel = ctx.Attr("InputChannel"); + int kernel_h = ctx.Attr("KernelH"); + int kernel_w = ctx.Attr("KernelW"); + + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* d_w = ctx.Output(framework::GradVarName("W")); + + Tensor col_grad; + col_grad.Resize(col->dims()); + auto* col_diff = col_grad.mutable_data(ctx.GetPlace()); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* w_diff = d_w->mutable_data(ctx.GetPlace()); + + memset(dx_data, 0.0, x->dims()[0] * x->dims()[1] * sizeof(T)); + memset(w_diff, 0.0, w->dims()[0] * w->dims()[1] * sizeof(T)); + memset(col_diff, 0.0, col->dims()[0] * col->dims()[1] * sizeof(T)); + auto* top_diff = d_out->data(); + auto* w_data = w->data(); + auto* col_data = col->data(); + int batch = x->lod()[0].size() - 1; + const auto& top_offset = out->lod()[0]; + const auto& col_offset = col->lod()[0]; + auto blas = math::GetBlas(ctx); + for (int b = 0; b < batch; ++b) { + int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; + if (top_im_size == 0) { + continue; + } + + blas.GEMM(CblasTrans, CblasNoTrans, input_channel * kernel_h * kernel_w, + top_im_size, output_channel, 1.0, w_data, + top_diff + top_offset[b], 1.0, col_diff + col_offset[b]); + + blas.GEMM(CblasNoTrans, CblasTrans, output_channel, + input_channel * kernel_h * kernel_w, top_im_size, 1.0, + top_diff + top_offset[b], col_data + col_offset[b], 1.0, + w_diff); + } + Im2ColGrad(ctx, col_diff); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plt = paddle::platform; +namespace frm = paddle::framework; +REGISTER_OPERATOR(var_conv_2d, ops::VarConv2dOP, ops::VarConv2dOpMaker, + frm::DefaultGradOpDescMaker); +REGISTER_OPERATOR(var_conv_2d_grad, ops::VarConv2dOpGrad); + +REGISTER_OP_CPU_KERNEL(var_conv_2d, + ops::CPUVarConv2dOPKernel); +// ops::CPUVarConv2dOPKernel +REGISTER_OP_CPU_KERNEL( + var_conv_2d_grad, + ops::CPUVarConv2dOPGradKernel); +// ops::CPUVarConv2dOPGradKernel diff --git a/paddle/fluid/operators/var_conv_2d_op.h b/paddle/fluid/operators/var_conv_2d_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b8d5de060934fa7ad5157c3718ddf0cc85771870 --- /dev/null +++ b/paddle/fluid/operators/var_conv_2d_op.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
+
+class VarConv2dOP : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class VarConv2dOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class VarConv2dOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 53fb3bad8555b8fb6b2e8e12e769aa939e8cf859..8c1abaf321769f1974848ceb1511dd0d24844fdf 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -210,6 +210,7 @@ __all__ = [
     'deformable_conv',
     'unfold',
     'deformable_roi_pooling',
+    'var_conv_2d',
     'shard_index',
 ]
 
@@ -12729,6 +12730,121 @@ def deformable_roi_pooling(input,
     return output
 
+
+def var_conv_2d(input,
+                row,
+                col,
+                input_channel,
+                output_channel,
+                filter_size,
+                stride=1,
+                param_attr=None,
+                act=None,
+                dtype='float32',
+                name=None):
+    """
+    The var_conv_2d layer calculates the output based on the :attr:`input` with variable length,
+    row, col, input channel, filter size and strides. The :attr:`input`, :attr:`row`,
+    and :attr:`col` are all 1-level LoDTensors. The convolution operation is the same as the
+    conv2d layer with padding. Besides, input.dims[1] should be 1.
+
+    .. code-block:: text
+
+        If input_channel is 2 and the given row lodTensor and col lodTensor are as follows:
+            row.lod = [[5, 4]]
+            col.lod = [[6, 7]]
+        input is a lodTensor:
+            input.lod = [[60, 56]]  # where 60 = input_channel * 5 * 6
+            input.dims = [116, 1]   # where 116 = 60 + 56
+
+        If output_channel is 3, filter_size is [3, 3] and stride is [1, 1]:
+            output.lod = [[90, 84]] # where 90 = output_channel * [(5-1)/stride + 1] * [(6-1)/stride + 1]
+            output.dims = [174, 1]  # where 174 = 90 + 84
+
+    Args:
+        input (Variable): The input should be a 1-level LoDTensor with dims[1] equal to 1.
+        row (Variable): The row should be a 1-level LoDTensor that provides height information.
+        col (Variable): The col should be a 1-level LoDTensor that provides width information.
+        input_channel (int): The number of input channels.
+        output_channel (int): The number of output channels.
+        filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        stride (int|tuple): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: stride = 1.
+        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
+            of var_conv_2d. If it is set to None or one attribute of ParamAttr, var_conv_2d
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
+            and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
+        act (str): Activation type. If it is set to None, no activation is appended.
+            Default: None.
+        dtype (str): The data type of the parameter and output. Default: 'float32'.
+ name (str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Default: None + + Returns: + Variable: Output variable with LoD specified by this layer. + + Examples: + .. code-block:: python + + import numpy as np + from paddle.fluid import layers + + x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1) + row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1) + col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1) + out = layers.var_conv_2d(input=x_lod_tensor, + row=row_lod_tensor, + col=col_lod_tensor, + input_channel=3, + output_channel=5, + filter_size=[3, 3], + stride=1) + """ + helper = LayerHelper('var_conv_2d', **locals()) + x_shape = list(input.shape) + assert len(x_shape) == 2 + + filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') + stride = utils.convert_to_list(stride, 2, 'stride') + + filter_shape = [ + int(output_channel), + int(input_channel) * filter_size[0] * filter_size[1] + ] + filter_param = helper.create_parameter( + attr=helper.param_attr, + shape=filter_shape, + dtype=dtype, ) + + conv_res = helper.create_variable_for_type_inference(dtype) + tmp_res = helper.create_variable_for_type_inference( + dtype, stop_gradient=True) + + helper.append_op( + type='var_conv_2d', + inputs={ + 'X': input, + 'ROW': row, + 'COLUMN': col, + 'W': filter_param, + }, + outputs={"Out": conv_res, + "Col": tmp_res}, + attrs={ + 'InputChannel': input_channel, + 'OutputChannel': output_channel, + 'StrideH': stride[0], + 'StrideW': stride[1], + 'KernelH': filter_size[0], + 'KernelW': filter_size[1], + }) + + return helper.append_activation(conv_res) + + def shard_index(input, index_num, nshards, shard_id, ignore_value=-1): """ This layer creates the sharded index for input. This layers is used in diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f706c9ce685f0a02f96d8b1ce6ba37d85e95304b..bc98523d85e04d82e7f93f04b9e1f0d93b1d3814 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -74,6 +74,11 @@ if(NOT WITH_MKLML) list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op) endif() +if(NOT WITH_MKL) + list(REMOVE_ITEM TEST_OPS test_var_conv_2d) +endif(NOT WITH_MKL) + + if(WITH_GPU OR NOT WITH_MKLML) # matmul with multiple heads need MKL support LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head) diff --git a/python/paddle/fluid/tests/unittests/test_var_conv_2d.py b/python/paddle/fluid/tests/unittests/test_var_conv_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e2db388318541801ac03c747be531fab882aa831 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_var_conv_2d.py @@ -0,0 +1,271 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestVarConv2dOp(OpTest): + def setUp(self): + self.init_op_type() + self.set_data() + self.compute() + + def init_op_type(self): + self.op_type = "var_conv_2d" + + def set_data(self): + input_channel = 3 + output_channel = 2 + filter_size = [2, 3] + stride = [1, 1] + row = [2, 4] + col = [3, 2] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + def init_data(self, input_channel, output_channel, filter_size, stride, row, + col): + + feature = [row[i] * col[i] for i in range(len(row))] + numel = sum(feature) * input_channel + x_data = np.random.random((numel, 1)).astype('float32') + x_lod = [[x * input_channel for x in feature]] + row_data = np.random.random((sum(row), 10)).astype('float32') + col_data = np.random.random((sum(col), 10)).astype('float32') + w_shape = (output_channel, + input_channel * filter_size[0] * filter_size[1]) + w_data = np.random.random(w_shape).astype('float32') + self.inputs = { + 'X': (x_data, x_lod), + 'ROW': (row_data, [row]), + 'COLUMN': (col_data, [col]), + 'W': w_data + } + self.attrs = { + 'InputChannel': input_channel, + 'OutputChannel': output_channel, + 'StrideH': stride[0], + 'StrideW': stride[1], + 'KernelH': filter_size[0], + 'KernelW': filter_size[1], + } + + def compute(self): + in_ch = self.attrs['InputChannel'] + out_ch = self.attrs['OutputChannel'] + kernel_h = self.attrs['KernelH'] + kernel_w = self.attrs['KernelW'] + stride_h = self.attrs['StrideH'] + stride_w = self.attrs['StrideW'] + row_data, row_lod = self.inputs['ROW'] + col_data, col_lod = self.inputs['COLUMN'] + x_data, x_lod = self.inputs['X'] + w_data = self.inputs['W'] + out_data = np.zeros((0, 1)).astype('float32') + + col_res_data, col_res_lod = self.Im2Col() + out_lod = [[]] + col_data_offset = 0 + batch_size = len(x_lod[0]) + for idx in range(batch_size): + width = col_lod[0][idx] + height = row_lod[0][idx] + top_im_x = 0 + if width != 0: + top_im_x = (width - 1) // stride_w + 1 + top_im_y = 0 + if height != 0: + top_im_y = (height - 1) // stride_h + 1 + top_im_size = top_im_x * top_im_y + out_lod[0].append(out_ch * top_im_size) + if top_im_size == 0: + out_tmp = np.zeros((out_ch * top_im_size, 1)).astype('float32') + else: + col_batch_data = col_res_data[col_data_offset:col_data_offset + + col_res_lod[0][idx]] + gemm_shape = (in_ch * kernel_h * kernel_w, top_im_size) + col_batch_data = col_batch_data.reshape(gemm_shape) + out_tmp = np.dot(w_data, col_batch_data).reshape(-1, 1) + out_data = np.vstack((out_data, out_tmp)) + + col_data_offset += col_res_lod[0][idx] + + self.outputs = { + 'Out': (out_data.astype('float32'), out_lod), + 'Col': (col_res_data, col_res_lod) + } + + def Im2Col(self): + in_ch = self.attrs['InputChannel'] + kernel_h = self.attrs['KernelH'] + kernel_w = self.attrs['KernelW'] + stride_h = self.attrs['StrideH'] + stride_w = self.attrs['StrideW'] + row_data, row_lod = self.inputs['ROW'] + col_data, col_lod = self.inputs['COLUMN'] + x_data, x_lod = self.inputs['X'] + col_res_lod = [[]] + top_size = 0 + batch_size = len(x_lod[0]) + for idx in range(batch_size): + width = col_lod[0][idx] + height = row_lod[0][idx] + top_im_x = 0 + if width != 0: + top_im_x = (width - 1) // stride_w + 1 + top_im_y = 0 + if height != 0: + top_im_y = (height - 1) // stride_h + 1 + top_x = top_im_x * top_im_y + top_y = in_ch * kernel_h * kernel_w + col_res_lod[0].append(top_x * top_y) + top_size += top_x * top_y + + col_res = 
np.zeros((top_size, 1)).astype('float32') + + kernel_win_size = kernel_h * kernel_w + half_kernel_h = kernel_h // 2 + half_kernel_w = kernel_w // 2 + t_offset, b_offset = 0, 0 + for idx in range(batch_size): + width = col_lod[0][idx] + height = row_lod[0][idx] + if width == 0 or height == 0: + continue + top_im_x = (width - 1) // stride_w + 1 + top_im_y = (height - 1) // stride_h + 1 + top_x = top_im_x * top_im_y + for z in range(in_ch): + row_offset = kernel_win_size * z + im_offset = z * width * height + for y in range(0, height, stride_h): + for x in range(0, width, stride_w): + col_offset = x // stride_w + y // stride_h * top_im_x + for ky in range(kernel_h): + for kx in range(kernel_w): + im_y = y + ky - half_kernel_h + im_x = x + kx - half_kernel_w + if im_x >= 0 and im_x < width and im_y >= 0 and im_y < height: + col_res[t_offset + + (row_offset + ky * kernel_w + kx) * top_x + + col_offset] = \ + x_data[b_offset + im_offset + im_y * width + im_x] + + t_offset += col_res_lod[0][idx] + b_offset += x_lod[0][idx] + + return col_res, col_res_lod + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', max_relative_error=0.005) + + +class TestVarConv2dOpCase1(TestVarConv2dOp): + def set_data(self): + # set in_ch 1 + input_channel = 1 + output_channel = 2 + filter_size = [2, 3] + stride = [1, 1] + row = [1, 4] + col = [3, 2] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase2(TestVarConv2dOp): + def set_data(self): + # set out_ch 1 + input_channel = 2 + output_channel = 1 + filter_size = [3, 3] + stride = [2, 2] + row = [4, 7] + col = [5, 2] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase3(TestVarConv2dOp): + def set_data(self): + # set batch 1 + input_channel = 2 + output_channel = 1 + filter_size = [3, 3] + stride = [2, 2] + row = [7] + col = [2] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase4(TestVarConv2dOp): + def set_data(self): + # set filter size very large + input_channel = 3 + output_channel = 4 + filter_size = [6, 6] + stride = [2, 2] + row = [4, 7] + col = [5, 2] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase5(TestVarConv2dOp): + def set_data(self): + # set input very small + input_channel = 5 + output_channel = 3 + filter_size = [3, 3] + stride = [1, 1] + row = [1, 1] + col = [1, 1] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase6(TestVarConv2dOp): + def set_data(self): + input_channel = 1 + output_channel = 3 + filter_size = [3, 3] + stride = [1, 1] + row = [1, 1] + col = [1, 1] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +class TestVarConv2dOpCase7(TestVarConv2dOp): + def set_data(self): + input_channel = 2 + output_channel = 3 + filter_size = [3, 3] + stride = [1, 1] + row = [5, 4] + col = [6, 7] + self.init_data(input_channel, output_channel, filter_size, stride, row, + col) + + +if __name__ == '__main__': + unittest.main()
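For reference, the LoD/shape bookkeeping that the var_conv_2d docstring and the test's Im2Col reference implementation rely on can be checked with a few lines of plain Python. The sketch below is illustrative only; the helper name expected_var_conv_2d_lod is ours and is not part of this patch.

# Illustrative sketch: reproduces the per-sample LoD arithmetic described in the
# var_conv_2d docstring example above (not code from the patch).

def expected_var_conv_2d_lod(row_lod, col_lod, input_channel, output_channel,
                             stride_h, stride_w):
    """Return (input_lod, output_lod) element counts per sample."""
    input_lod, output_lod = [], []
    for height, width in zip(row_lod, col_lod):
        # Each sample stores input_channel feature maps of size height x width,
        # flattened into a [N, 1] LoDTensor.
        input_lod.append(input_channel * height * width)
        # Output spatial size follows (dim - 1) // stride + 1, zero for empty samples.
        top_h = 0 if height == 0 else (height - 1) // stride_h + 1
        top_w = 0 if width == 0 else (width - 1) // stride_w + 1
        output_lod.append(output_channel * top_h * top_w)
    return input_lod, output_lod


if __name__ == '__main__':
    # Numbers from the docstring example: row.lod=[[5, 4]], col.lod=[[6, 7]],
    # input_channel=2, output_channel=3, stride=[1, 1].
    in_lod, out_lod = expected_var_conv_2d_lod([5, 4], [6, 7], 2, 3, 1, 1)
    assert in_lod == [60, 56]    # input.lod  = [[60, 56]]
    assert out_lod == [90, 84]   # output.lod = [[90, 84]]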