提交 e681d655 编写于 作者: K Kevin 提交者: Tao Luo

Add var_conv_2d op (#18518)

* fix overflow by int32 mul test=develop

* fix reference nullptr

* fix codestyle test=develop

* modify to point in ContextProjectFunctor test=develop

* modify to point in ContextProjectFunctor test=develop

* modify . to -> test=develop

* add var_conv_2d op test=develop

* edit api.spec test=develop

* ignore unittest if with_mkl=off test=develop

* fix python3 division test=develop

* fix ignore unittest bug test=develop

* remove useless code test=develop

* modify api.spec test=develop

* modify default_grad.spec test=develop
上级 81fe02c3
......@@ -267,6 +267,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899'))
......
......@@ -41,3 +41,4 @@ tensor_array_to_tensor
transpose
unpool
unsqueeze
var_conv_2d
......@@ -48,8 +48,13 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
SET(OP_ONLY_MKL "")
if (NOT WITH_MKL)
SET(OP_ONLY_MKL ${OP_ONLY_MKL} var_conv_2d_op)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/var_conv_2d_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
void VarConv2dOpMaker::Make() {
AddInput("X",
"X (LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information.");
AddInput("ROW", "(LoDTensor) the row variable provides lod information");
AddInput("COLUMN",
"(LoDTensor) the column variable provides lod information");
AddInput("W", "W (Tensor), the filter.");
AddAttr<int>("InputChannel", "the input filter num").SetDefault(1);
AddAttr<int>("OutputChannel", "the output filter num").SetDefault(1);
AddAttr<int>("StrideH", "the height of Stride").SetDefault(1);
AddAttr<int>("StrideW", "the width of Stride").SetDefault(1);
AddAttr<int>("KernelH", "the height of Kernel").SetDefault(1);
AddAttr<int>("KernelW", "the width of Kernel").SetDefault(1);
AddOutput("Out", "(LoDTensor, default LoDTensor<float>) Output variable");
AddOutput("Col",
"(LoDTensor, default LoDTensor<float>) the intermediate result "
"variable");
AddComment(R"DOC(
Var Size Conv Operator
This operator calculate Out = \sigma \left ( W * X + b \right ),
only support 2-D for X.
NOTE: only support 'float32' data type now.
)DOC");
}
void VarConv2dOP::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROW"),
"Input(ROW) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("COLUMN"),
"Input(COLUMN) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of VarConv2dOP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Col"),
"Col(Output) of VarConv2dOP should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of X(Input) can't be less than 2.");
auto w_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "W should be 2-D tensor");
int output_channel = ctx->Attrs().Get<int>("OutputChannel");
int input_channel = ctx->Attrs().Get<int>("InputChannel");
int kernel_h = ctx->Attrs().Get<int>("KernelH");
int kernel_w = ctx->Attrs().Get<int>("KernelW");
PADDLE_ENFORCE_EQ(w_dims[0], output_channel,
"W dim[0] should be equal to OutputChannel");
PADDLE_ENFORCE_EQ(
w_dims[1], input_channel * kernel_h * kernel_w,
"W dim[1] should be equal to InputChannel * StrideH * StrideW");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
const auto& x_lod = x_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
PADDLE_ENFORCE_GE(x_lod.size(), 1, "The Input(X)'s lod info is corrupted.");
PADDLE_ENFORCE_EQ(
x_dims[0], static_cast<int64_t>(x_lod[0].back()),
"The Input(X)'s lod info mismatches the actual tensor shape.");
framework::Variable* row_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("ROW")[0]);
const auto& row_lod = row_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!row_lod.empty(), "The Input(ROW) must hold lod info.");
framework::Variable* col_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("COLUMN")[0]);
const auto& col_lod = col_var->Get<LoDTensor>().lod();
PADDLE_ENFORCE(!col_lod.empty(), "The Input(COLUMN) must hold lod info.");
} else {
std::vector<int64_t> out_dims_vec{-1};
out_dims_vec.push_back(1);
std::vector<int64_t> col_dims_vec{-1};
col_dims_vec.push_back(1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
ctx->SetOutputDim("Col", framework::make_ddim(col_dims_vec));
}
}
template <typename DeviceContext, typename T>
class CPUVarConv2dOPKernel : public framework::OpKernel<T> {
public:
void Im2Col(const framework::ExecutionContext& ctx, const LoDTensor& input,
LoDTensor* col) const {
int input_channel = ctx.Attr<int>("InputChannel");
auto* in_row = ctx.Input<LoDTensor>("ROW");
auto* in_col = ctx.Input<LoDTensor>("COLUMN");
int kernel_h = ctx.Attr<int>("KernelH");
int kernel_w = ctx.Attr<int>("KernelW");
int stride_h = ctx.Attr<int>("StrideH");
int stride_w = ctx.Attr<int>("StrideW");
int batch = input.lod()[0].size() - 1;
const auto& bottom_offset = input.lod()[0];
// 2-D lod info.
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// top offset is the whole size of each data sample
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_x = top_im_y * top_im_x;
int top_y = input_channel * kernel_h * kernel_w;
top_size += top_y * top_x;
top_offset.push_back(top_size);
}
framework::LoD col_lod;
col_lod.push_back(top_offset);
col->set_lod(col_lod);
std::vector<int64_t> col_dims_vec{top_size};
col_dims_vec.push_back(1);
auto* top_data = col->mutable_data<T>(framework::make_ddim(col_dims_vec),
ctx.GetPlace());
auto* bottom_data = input.data<T>();
int kernel_win_size = kernel_h * kernel_w;
int half_kernel_h = kernel_h / 2;
int half_kernel_w = kernel_w / 2;
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int top_im_x = (width - 1) / stride_w + 1;
int top_im_y = (height - 1) / stride_h + 1;
int top_x = top_im_y * top_im_x;
for (int z = 0; z < input_channel; ++z) {
int row_offset = kernel_win_size * z;
int im_offset = z * width * height;
for (int y = 0; y < height; y += stride_h) {
for (int x = 0; x < width; x += stride_w) {
int col_offset = x / stride_w + y / stride_h * top_im_x;
for (int ky = 0; ky < kernel_h; ++ky) {
for (int kx = 0; kx < kernel_w; ++kx) {
int im_y = y + ky - half_kernel_h;
int im_x = x + kx - half_kernel_w;
if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
top_data[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset] =
bottom_data[b_offset + im_offset + im_y * width + im_x];
} else {
top_data[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset] = 0;
}
}
}
}
}
}
}
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* bottom = ctx.Input<LoDTensor>("X");
auto* in_row = ctx.Input<LoDTensor>("ROW");
auto* in_col = ctx.Input<LoDTensor>("COLUMN");
auto* w = ctx.Input<Tensor>("W");
auto* top = ctx.Output<LoDTensor>("Out");
auto* col = ctx.Output<LoDTensor>("Col");
int output_channel = ctx.Attr<int>("OutputChannel");
int input_channel = ctx.Attr<int>("InputChannel");
int kernel_h = ctx.Attr<int>("KernelH");
int kernel_w = ctx.Attr<int>("KernelW");
int stride_h = ctx.Attr<int>("StrideH");
int stride_w = ctx.Attr<int>("StrideW");
Im2Col(ctx, *bottom, col);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
framework::LoD top_lod;
top_lod.push_back(top_offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
auto* top_data = top->mutable_data<T>(framework::make_ddim(top_dims_vec),
ctx.GetPlace());
auto* w_data = w->data<T>();
auto* col_data = col->data<T>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
blas.GEMM(CblasNoTrans, CblasNoTrans, output_channel, top_im_size,
input_channel * kernel_h * kernel_w, 1.0, w_data,
col_data + col_offset[b], 0.0, top_data + top_offset[b]);
}
}
};
void VarConv2dOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of SequencePadGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
}
if (ctx->HasOutput(framework::GradVarName("W"))) {
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
}
}
template <typename DeviceContext, typename T>
class CPUVarConv2dOPGradKernel : public framework::OpKernel<T> {
public:
void Im2ColGrad(const framework::ExecutionContext& ctx, T* top_diff) const {
auto* x = ctx.Input<LoDTensor>("X");
auto* in_row = ctx.Input<LoDTensor>("ROW");
auto* in_col = ctx.Input<LoDTensor>("COLUMN");
auto* col = ctx.Input<LoDTensor>("Col");
int input_channel = ctx.Attr<int>("InputChannel");
int kernel_h = ctx.Attr<int>("KernelH");
int kernel_w = ctx.Attr<int>("KernelW");
int stride_h = ctx.Attr<int>("StrideH");
int stride_w = ctx.Attr<int>("StrideW");
auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
memset(dx_data, 0.0, x->dims()[0] * x->dims()[1] * sizeof(T));
const auto& bottom_offset = x->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
const auto& top_offset = col->lod()[0];
int batch = x->lod()[0].size() - 1;
int kernel_win_size = kernel_h * kernel_w;
int half_kernel_h = kernel_h / 2;
int half_kernel_w = kernel_w / 2;
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int top_im_x = (width - 1) / stride_w + 1;
int top_im_y = (height - 1) / stride_h + 1;
int top_x = top_im_y * top_im_x;
for (int z = 0; z < input_channel; ++z) {
int row_offset = kernel_win_size * z;
int im_offset = z * width * height;
for (int y = 0; y < height; y += stride_h) {
for (int x = 0; x < width; x += stride_w) {
int col_offset = x / stride_w + y / stride_h * top_im_x;
for (int ky = 0; ky < kernel_h; ++ky) {
for (int kx = 0; kx < kernel_w; ++kx) {
int im_y = y + ky - half_kernel_h;
int im_x = x + kx - half_kernel_w;
if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
dx_data[b_offset + im_offset + im_y * width + im_x] +=
top_diff[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset];
}
}
}
}
}
}
}
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("W");
auto* col = ctx.Input<LoDTensor>("Col");
auto* out = ctx.Input<LoDTensor>("Out");
int output_channel = ctx.Attr<int>("OutputChannel");
int input_channel = ctx.Attr<int>("InputChannel");
int kernel_h = ctx.Attr<int>("KernelH");
int kernel_w = ctx.Attr<int>("KernelW");
auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* d_w = ctx.Output<Tensor>(framework::GradVarName("W"));
Tensor col_grad;
col_grad.Resize(col->dims());
auto* col_diff = col_grad.mutable_data<T>(ctx.GetPlace());
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* w_diff = d_w->mutable_data<T>(ctx.GetPlace());
memset(dx_data, 0.0, x->dims()[0] * x->dims()[1] * sizeof(T));
memset(w_diff, 0.0, w->dims()[0] * w->dims()[1] * sizeof(T));
memset(col_diff, 0.0, col->dims()[0] * col->dims()[1] * sizeof(T));
auto* top_diff = d_out->data<T>();
auto* w_data = w->data<T>();
auto* col_data = col->data<T>();
int batch = x->lod()[0].size() - 1;
const auto& top_offset = out->lod()[0];
const auto& col_offset = col->lod()[0];
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
blas.GEMM(CblasTrans, CblasNoTrans, input_channel * kernel_h * kernel_w,
top_im_size, output_channel, 1.0, w_data,
top_diff + top_offset[b], 1.0, col_diff + col_offset[b]);
blas.GEMM(CblasNoTrans, CblasTrans, output_channel,
input_channel * kernel_h * kernel_w, top_im_size, 1.0,
top_diff + top_offset[b], col_data + col_offset[b], 1.0,
w_diff);
}
Im2ColGrad(ctx, col_diff);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plt = paddle::platform;
namespace frm = paddle::framework;
REGISTER_OPERATOR(var_conv_2d, ops::VarConv2dOP, ops::VarConv2dOpMaker,
frm::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(var_conv_2d_grad, ops::VarConv2dOpGrad);
REGISTER_OP_CPU_KERNEL(var_conv_2d,
ops::CPUVarConv2dOPKernel<plt::CPUDeviceContext, float>);
// ops::CPUVarConv2dOPKernel<plt::CPUDeviceContext,
// double>
REGISTER_OP_CPU_KERNEL(
var_conv_2d_grad,
ops::CPUVarConv2dOPGradKernel<plt::CPUDeviceContext, float>);
// ops::CPUVarConv2dOPGradKernel<plt::CPUDeviceContext,
// double>
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
class VarConv2dOP : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
};
class VarConv2dOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
};
class VarConv2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -210,6 +210,7 @@ __all__ = [
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'var_conv_2d',
'shard_index',
]
......@@ -12729,6 +12730,121 @@ def deformable_roi_pooling(input,
return output
def var_conv_2d(input,
row,
col,
input_channel,
output_channel,
filter_size,
stride=1,
param_attr=None,
act=None,
dtype='float32',
name=None):
"""
The var_conv_2d layer calculates the output base on the :attr:`input` with variable length,
row, col, input channel, filter size and strides. Both :attr:`input`, :attr:`row`,
and :attr:`col` are 1-level LodTensor. The covolution operation is same as conv2d layer with
padding. Besides, input.dims[1] should be 1.
.. code-block:: text
If input_channel is 2 and given row lodTensor and col lodTensor as follows:
row.lod = [[5, 4]]
col.lod = [[6, 7]]
input is a lodTensor:
input.lod = [[60, 56]] # where 60 = input_channel * 5 * 6
input.dims = [116, 1] # where 116 = 60 + 56
If set output_channel is 3, filter_size is [3, 3], stride is [1, 1]:
output.lod = [[90, 84]] # where 90 = output_channel * [(5-1)/stride + 1] * [(6-1)/stride + 1]
output.dims = [174, 1] # where 174 = 90 + 84
Args:
input (Variable): The input shoud be 1-level LodTensor with dims[1] equals 1.
row (Variable): The row shoud be 1-level LodTensor to provide height information.
col (Variable): The col shoud be 1-level LodTensor to provide width information.
input_channel (int): The number of input channel.
output_channel (int): The number of output channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of var_conv2d. If it is set to None or one attribute of ParamAttr, var_conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
dtype ('float32'): The data type of parameter and output.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: Output variable with LoD specified by this layer.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
out = layers.var_conv_2d(input=x_lod_tensor,
row=row_lod_tensor,
col=col_lod_tensor,
input_channel=3,
output_channel=5,
filter_size=[3, 3],
stride=1)
"""
helper = LayerHelper('var_conv_2d', **locals())
x_shape = list(input.shape)
assert len(x_shape) == 2
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
filter_shape = [
int(output_channel),
int(input_channel) * filter_size[0] * filter_size[1]
]
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype, )
conv_res = helper.create_variable_for_type_inference(dtype)
tmp_res = helper.create_variable_for_type_inference(
dtype, stop_gradient=True)
helper.append_op(
type='var_conv_2d',
inputs={
'X': input,
'ROW': row,
'COLUMN': col,
'W': filter_param,
},
outputs={"Out": conv_res,
"Col": tmp_res},
attrs={
'InputChannel': input_channel,
'OutputChannel': output_channel,
'StrideH': stride[0],
'StrideW': stride[1],
'KernelH': filter_size[0],
'KernelW': filter_size[1],
})
return helper.append_activation(conv_res)
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
This layer creates the sharded index for input. This layers is used in
......
......@@ -74,6 +74,11 @@ if(NOT WITH_MKLML)
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
endif()
if(NOT WITH_MKL)
list(REMOVE_ITEM TEST_OPS test_var_conv_2d)
endif(NOT WITH_MKL)
if(WITH_GPU OR NOT WITH_MKLML)
# matmul with multiple heads need MKL support
LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestVarConv2dOp(OpTest):
def setUp(self):
self.init_op_type()
self.set_data()
self.compute()
def init_op_type(self):
self.op_type = "var_conv_2d"
def set_data(self):
input_channel = 3
output_channel = 2
filter_size = [2, 3]
stride = [1, 1]
row = [2, 4]
col = [3, 2]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
def init_data(self, input_channel, output_channel, filter_size, stride, row,
col):
feature = [row[i] * col[i] for i in range(len(row))]
numel = sum(feature) * input_channel
x_data = np.random.random((numel, 1)).astype('float32')
x_lod = [[x * input_channel for x in feature]]
row_data = np.random.random((sum(row), 10)).astype('float32')
col_data = np.random.random((sum(col), 10)).astype('float32')
w_shape = (output_channel,
input_channel * filter_size[0] * filter_size[1])
w_data = np.random.random(w_shape).astype('float32')
self.inputs = {
'X': (x_data, x_lod),
'ROW': (row_data, [row]),
'COLUMN': (col_data, [col]),
'W': w_data
}
self.attrs = {
'InputChannel': input_channel,
'OutputChannel': output_channel,
'StrideH': stride[0],
'StrideW': stride[1],
'KernelH': filter_size[0],
'KernelW': filter_size[1],
}
def compute(self):
in_ch = self.attrs['InputChannel']
out_ch = self.attrs['OutputChannel']
kernel_h = self.attrs['KernelH']
kernel_w = self.attrs['KernelW']
stride_h = self.attrs['StrideH']
stride_w = self.attrs['StrideW']
row_data, row_lod = self.inputs['ROW']
col_data, col_lod = self.inputs['COLUMN']
x_data, x_lod = self.inputs['X']
w_data = self.inputs['W']
out_data = np.zeros((0, 1)).astype('float32')
col_res_data, col_res_lod = self.Im2Col()
out_lod = [[]]
col_data_offset = 0
batch_size = len(x_lod[0])
for idx in range(batch_size):
width = col_lod[0][idx]
height = row_lod[0][idx]
top_im_x = 0
if width != 0:
top_im_x = (width - 1) // stride_w + 1
top_im_y = 0
if height != 0:
top_im_y = (height - 1) // stride_h + 1
top_im_size = top_im_x * top_im_y
out_lod[0].append(out_ch * top_im_size)
if top_im_size == 0:
out_tmp = np.zeros((out_ch * top_im_size, 1)).astype('float32')
else:
col_batch_data = col_res_data[col_data_offset:col_data_offset +
col_res_lod[0][idx]]
gemm_shape = (in_ch * kernel_h * kernel_w, top_im_size)
col_batch_data = col_batch_data.reshape(gemm_shape)
out_tmp = np.dot(w_data, col_batch_data).reshape(-1, 1)
out_data = np.vstack((out_data, out_tmp))
col_data_offset += col_res_lod[0][idx]
self.outputs = {
'Out': (out_data.astype('float32'), out_lod),
'Col': (col_res_data, col_res_lod)
}
def Im2Col(self):
in_ch = self.attrs['InputChannel']
kernel_h = self.attrs['KernelH']
kernel_w = self.attrs['KernelW']
stride_h = self.attrs['StrideH']
stride_w = self.attrs['StrideW']
row_data, row_lod = self.inputs['ROW']
col_data, col_lod = self.inputs['COLUMN']
x_data, x_lod = self.inputs['X']
col_res_lod = [[]]
top_size = 0
batch_size = len(x_lod[0])
for idx in range(batch_size):
width = col_lod[0][idx]
height = row_lod[0][idx]
top_im_x = 0
if width != 0:
top_im_x = (width - 1) // stride_w + 1
top_im_y = 0
if height != 0:
top_im_y = (height - 1) // stride_h + 1
top_x = top_im_x * top_im_y
top_y = in_ch * kernel_h * kernel_w
col_res_lod[0].append(top_x * top_y)
top_size += top_x * top_y
col_res = np.zeros((top_size, 1)).astype('float32')
kernel_win_size = kernel_h * kernel_w
half_kernel_h = kernel_h // 2
half_kernel_w = kernel_w // 2
t_offset, b_offset = 0, 0
for idx in range(batch_size):
width = col_lod[0][idx]
height = row_lod[0][idx]
if width == 0 or height == 0:
continue
top_im_x = (width - 1) // stride_w + 1
top_im_y = (height - 1) // stride_h + 1
top_x = top_im_x * top_im_y
for z in range(in_ch):
row_offset = kernel_win_size * z
im_offset = z * width * height
for y in range(0, height, stride_h):
for x in range(0, width, stride_w):
col_offset = x // stride_w + y // stride_h * top_im_x
for ky in range(kernel_h):
for kx in range(kernel_w):
im_y = y + ky - half_kernel_h
im_x = x + kx - half_kernel_w
if im_x >= 0 and im_x < width and im_y >= 0 and im_y < height:
col_res[t_offset +
(row_offset + ky * kernel_w + kx) * top_x +
col_offset] = \
x_data[b_offset + im_offset + im_y * width + im_x]
t_offset += col_res_lod[0][idx]
b_offset += x_lod[0][idx]
return col_res, col_res_lod
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.005)
class TestVarConv2dOpCase1(TestVarConv2dOp):
def set_data(self):
# set in_ch 1
input_channel = 1
output_channel = 2
filter_size = [2, 3]
stride = [1, 1]
row = [1, 4]
col = [3, 2]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase2(TestVarConv2dOp):
def set_data(self):
# set out_ch 1
input_channel = 2
output_channel = 1
filter_size = [3, 3]
stride = [2, 2]
row = [4, 7]
col = [5, 2]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase3(TestVarConv2dOp):
def set_data(self):
# set batch 1
input_channel = 2
output_channel = 1
filter_size = [3, 3]
stride = [2, 2]
row = [7]
col = [2]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase4(TestVarConv2dOp):
def set_data(self):
# set filter size very large
input_channel = 3
output_channel = 4
filter_size = [6, 6]
stride = [2, 2]
row = [4, 7]
col = [5, 2]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase5(TestVarConv2dOp):
def set_data(self):
# set input very small
input_channel = 5
output_channel = 3
filter_size = [3, 3]
stride = [1, 1]
row = [1, 1]
col = [1, 1]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase6(TestVarConv2dOp):
def set_data(self):
input_channel = 1
output_channel = 3
filter_size = [3, 3]
stride = [1, 1]
row = [1, 1]
col = [1, 1]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
class TestVarConv2dOpCase7(TestVarConv2dOp):
def set_data(self):
input_channel = 2
output_channel = 3
filter_size = [3, 3]
stride = [1, 1]
row = [5, 4]
col = [6, 7]
self.init_data(input_channel, output_channel, filter_size, stride, row,
col)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册