未验证 提交 4fc9f55e 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #5472 from chengduoZH/refine_im2col

Add dilations for conv2d and optimize conv2d code
...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { ...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker(framework::OpProto* proto, CudnnConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) { : Conv2DOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB", AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, " "workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be " "workspace is a section of GPU memory which will be "
......
...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int input_channels = in_dims[1]; int input_channels = in_dims[1];
int output_channels = filter_dims[0]; int output_channels = filter_dims[0];
...@@ -52,9 +53,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -52,9 +53,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"The number of output channels should be divided by groups."); "The number of output channels should be divided by groups.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
(dilations[i] * (filter_dims[i + 2] - 1) + 1) >
0,
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
paddings[i], strides[i])); dilations[i], paddings[i], strides[i]));
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
...@@ -78,9 +85,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, ...@@ -78,9 +85,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.") AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1}); .SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.") AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0, 0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddAttr<int>( AddAttr<int>(
"groups", "groups",
...@@ -90,15 +103,20 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, ...@@ -90,15 +103,20 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution Operator. Convolution Operator.
The convolution operation calculates the output based on the input, filter The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the and strides, paddings, groups, dilations parameters. The size of each dimension of the
parameters is checked in the infer-shape. parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is size, C is the number of channels, H is the height of the feature, and W is
the width of the feature. Parameters(ksize, strides, paddings) are two elements. the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements.
These two elements represent height and width, respectively. These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different. The input(X) size and output(Out) size may be different.
...@@ -109,8 +127,8 @@ Example: ...@@ -109,8 +127,8 @@ Example:
Output: Output:
Output shape: (N, C_out, H_out, W_out) Output shape: (N, C_out, H_out, W_out)
where where
H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1;
W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1;
)DOC"); )DOC");
} }
...@@ -135,13 +153,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, ...@@ -135,13 +153,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator." "(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("strides",
"strides", "(vector<int>, default:{1, 1, 1}), the "
"(vector, default:{0, 0, 0}), the strides of convolution operator.") "strides(d_stride, h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1, 1}); .SetDefault({1, 1, 1});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("paddings",
"paddings", "(vector<int>, default:{0, 0, 0}), the "
"(vector, default:{0, 0, 0}), the paddings of convolution operator.") "paddings(d_pad, h_pad, w_pad) of convolution "
"operator.")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0});
AddAttr<int>( AddAttr<int>(
"groups", "groups",
...@@ -151,6 +171,12 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, ...@@ -151,6 +171,12 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation, h_dilation, w_dilation) of "
"convolution operator. Currently, conv3d doesn't "
"support dilation.")
.SetDefault({1, 1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution3D Operator. Convolution3D Operator.
......
...@@ -27,11 +27,24 @@ using Tensor = framework::Tensor; ...@@ -27,11 +27,24 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv // Base convolution operator definations for other conv
// like operators to reuse the implementation. // like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int padding, inline int OutputSize(int input_size, int filter_size, int dilation,
int stride) { int padding, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1; const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
return output_size; return output_size;
} }
inline bool IsExpand(std::vector<int64_t>& filter_dim,
std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
// Define Op classes in .h file so that other conv // Define Op classes in .h file so that other conv
// operator implementations can reuse the code. // operator implementations can reuse the code.
...@@ -50,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvOp : public framework::OperatorWithKernel { class ConvOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
class ConvOpGrad : public framework::OperatorWithKernel { class ConvOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
...@@ -73,9 +84,10 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -73,9 +84,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -106,14 +118,17 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -106,14 +118,17 @@ class GemmConvKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape = framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col, // col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape // but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix; Tensor col_matrix;
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace());
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim( framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size())); input->dims(), 1, static_cast<int>(input->dims().size()));
...@@ -130,24 +145,30 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -130,24 +145,30 @@ class GemmConvKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups; int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups; int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<Place, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (filter_shape_vec.size() == 2) { if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (filter_shape_vec.size() == 2) {
// im2col // im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; im2col(context.device_context(), in_slice, dilations, strides,
im2col(context.device_context(), in_slice, col, strides[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]},
paddings[1]); &col);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col // vol2col
math::Vol2ColFunctor<Place, T> vol2col; vol2col(context.device_context(), in_slice, dilations, strides,
vol2col(context.device_context(), in_slice, col, strides[0], paddings, &col);
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
} }
// gemm // gemm
...@@ -178,9 +199,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -178,9 +199,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (!input_grad && !filter_grad) return; if (!input_grad && !filter_grad) return;
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -230,14 +252,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -230,14 +252,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups; int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output_grad->dims()[1]) / groups; int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col; Tensor col;
// col_matrix shares the same piece of data with col, // col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape // but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix; Tensor col_matrix;
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
}
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
...@@ -245,6 +270,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -245,6 +270,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0)); set_zero(context.device_context(), input_grad, static_cast<T>(0));
math::Col2VolFunctor<Place, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape); output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -254,24 +282,26 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -254,24 +282,26 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
Tensor out_grad_slice = Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step); out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
// col2im
Tensor in_grad_slice = Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step); in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
if (filter_shape_vec.size() == 2) { if (!is_expand) {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; col_matrix.ShareDataWith(in_grad_slice);
col2im(context.device_context(), in_grad_slice, col, strides[0], col_matrix.Resize(col_matrix_shape);
strides[1], paddings[0], paddings[0], paddings[1], }
paddings[1]); math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
} else if (filter_shape_vec.size() == 3) { if (is_expand && filter_shape_vec.size() == 2) {
math::Col2VolFunctor<Place, T> col2vol; col2im(context.device_context(), col, dilations, strides,
col2vol(context.device_context(), in_grad_slice, col, strides[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
strides[1], strides[2], paddings[0], paddings[1], paddings[1]},
paddings[2]); &in_grad_slice);
} else if (is_expand && filter_shape_vec.size() == 3) {
col2vol(context.device_context(), col, dilations, strides, paddings,
&in_grad_slice);
} }
} }
} }
...@@ -282,7 +312,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -282,7 +312,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
Tensor filter_grad_ = *filter_grad; Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
set_zero(context.device_context(), filter_grad, static_cast<T>(0)); set_zero(context.device_context(), filter_grad, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape); output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -293,16 +324,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -293,16 +324,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
out_grad_batch.Slice(g * out_step, (g + 1) * out_step); out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (filter_shape_vec.size() == 2) { if (!is_expand) {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; col.ShareDataWith(in_slice);
im2col(context.device_context(), in_slice, col, strides[0], col_matrix.ShareDataWith(col);
strides[1], paddings[0], paddings[0], paddings[1], col_matrix.Resize(col_matrix_shape);
paddings[1]); } else if (filter_shape_vec.size() == 2) {
im2col(context.device_context(), in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col; vol2col(context.device_context(), in_slice, dilations, strides,
vol2col(context.device_context(), in_slice, col, strides[0], paddings, &col);
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
} }
// gemm // gemm
......
...@@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"as the number of filters."); "as the number of filters.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] + output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
filter_dims[i + 2]); filter_dims[i + 2]);
} }
...@@ -79,11 +79,13 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( ...@@ -79,11 +79,13 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
"(vector defalut:{1, 1}), strides of convolution transpose operator.") "(vector<int> defalut:{1, 1}), the strides(h_stride, w_stride) of "
"convolution transpose operator.")
.SetDefault({1, 1}); .SetDefault({1, 1});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector defalut:{0, 0}), paddings of convolution transpose operator.") "(vector<int> defalut:{0, 0}), the paddings(h_pad, w_pad) of convolution "
"transpose operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddComment(R"DOC( AddComment(R"DOC(
Convolution2D Transpose Operator. Convolution2D Transpose Operator.
...@@ -132,13 +134,14 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( ...@@ -132,13 +134,14 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the " "the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature."); "height of the feature, and W is the width of the feature.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("strides",
"strides", "(vector<int> defalut:{1, 1, 1}), the "
"(vector defalut:{1, 1, 1}), strides of convolution transpose operator.") "strides{d_stride, h_stride, w_stride} of "
"convolution transpose operator.")
.SetDefault({1, 1, 1}); .SetDefault({1, 1, 1});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("paddings",
"paddings", "(vector<int> defalut:{0, 0, 0}), paddings(d_pad, "
"(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.") "h_pad, w_pad) of convolution transpose operator.")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0});
AddComment(R"DOC( AddComment(R"DOC(
Convolution3D Transpose Operator. Convolution3D Transpose Operator.
......
...@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvTransposeOp : public framework::OperatorWithKernel { class ConvTransposeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
class ConvTransposeOpGrad : public framework::OperatorWithKernel { class ConvTransposeOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
...@@ -66,6 +62,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -66,6 +62,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Output");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
// TODO(Zhuoyuan): Paddings can be added in future. // TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose. // groups will alway be disabled in conv2dtranspose.
...@@ -120,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -120,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), output, static_cast<T>(0)); set_zero(context.device_context(), output, static_cast<T>(0));
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<Place, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input) // on input)
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
...@@ -138,16 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -138,16 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
// col2im: col_matrix -> dy // col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w) // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; col2im(context.device_context(), col,
std::vector<int>{dilations[0], dilations[1]}, strides,
col2im(context.device_context(), output_batch, col, strides[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
strides[1], 0, 0, 0, 0); paddings[1]},
&output_batch);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// col2vol: col_matrix -> dy // col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
math::Col2VolFunctor<Place, T> col2vol; col2vol(context.device_context(), col, dilations, strides,
col2vol(context.device_context(), output_batch, col, strides[0], std::vector<int>{0, 0, 0}, &output_batch);
strides[1], strides[2], 0, 0, 0);
} }
} }
} }
...@@ -228,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -228,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor filter_grad_; Tensor filter_grad_;
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0)); set_zero(context.device_context(), input_grad, static_cast<T>(0));
...@@ -247,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -247,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
// im2col: dy -> col matrix // im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w) // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; im2col(context.device_context(), output_grad_batch,
im2col(context.device_context(), output_grad_batch, col, strides[0], std::vector<int>{dilations[0], dilations[1]}, strides,
strides[1], paddings[0], paddings[0], paddings[1], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]); paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col: dy -> col_matrix // vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
math::Vol2ColFunctor<Place, T> vol2col; vol2col(context.device_context(), output_grad_batch, dilations,
vol2col(context.device_context(), output_grad_batch, col, strides[0], strides, paddings, &col);
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
} }
if (input_grad) { if (input_grad) {
......
...@@ -88,13 +88,18 @@ template <typename Place, typename T> ...@@ -88,13 +88,18 @@ template <typename Place, typename T>
class ContextProjectFunctor { class ContextProjectFunctor {
public: public:
void operator()(const platform::DeviceContext& context, const LoDTensor& in, void operator()(const platform::DeviceContext& context, const LoDTensor& in,
const Tensor& padding_data, Tensor& col, const Tensor& padding_data, bool padding_trainable,
bool padding_trainable, int context_start, int context_length, const int context_start, const int context_length,
int context_stride, int up_pad, int down_pad) { const int context_stride, const int up_pad,
const int down_pad, Tensor* col) {
auto lod_level_0 = in.lod()[0]; auto lod_level_0 = in.lod()[0];
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf; math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
std::vector<int> stride({context_stride, 1});
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -105,7 +110,7 @@ class ContextProjectFunctor { ...@@ -105,7 +110,7 @@ class ContextProjectFunctor {
: static_cast<int>(lod_level_0[i]); : static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]); input_row_end = static_cast<int>(lod_level_0[i + 1]);
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]), Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1])); static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
...@@ -123,16 +128,13 @@ class ContextProjectFunctor { ...@@ -123,16 +128,13 @@ class ContextProjectFunctor {
{1, input_row_end - input_row_begin, {1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context, in_t, dilation, stride, padding, &out_t);
im2col_ocf(context, in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
} }
} }
if (padding_trainable) { if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]), Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1])); static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
...@@ -196,14 +198,19 @@ class ContextProjectFunctor { ...@@ -196,14 +198,19 @@ class ContextProjectFunctor {
template <typename Place, typename T> template <typename Place, typename T>
class ContextProjectGradFunctor { class ContextProjectGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, LoDTensor& in, void operator()(const platform::DeviceContext& context, const LoDTensor& in,
Tensor& padding_data, Tensor& col, bool padding_trainable, bool padding_trainable, const int context_start,
int context_start, int context_length, int context_stride, const int context_length, const int context_stride,
int up_pad, int down_pad, bool input_grad, bool pad_grad) { const int up_pad, const int down_pad, bool pad_grad,
bool input_grad, Tensor* padding_data, Tensor* col) {
auto lod_level_0 = in.lod()[0]; auto lod_level_0 = in.lod()[0];
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf; math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
std::vector<int> stride({context_stride, 1});
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -215,7 +222,7 @@ class ContextProjectGradFunctor { ...@@ -215,7 +222,7 @@ class ContextProjectGradFunctor {
: static_cast<int>(lod_level_0[i]); : static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]); input_row_end = static_cast<int>(lod_level_0[i + 1]);
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]), Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1])); static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
...@@ -234,9 +241,7 @@ class ContextProjectGradFunctor { ...@@ -234,9 +241,7 @@ class ContextProjectGradFunctor {
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context, in_t, out_t, col2im_ocf(context, out_t, dilation, stride, padding, &in_t);
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
} }
} }
...@@ -244,7 +249,7 @@ class ContextProjectGradFunctor { ...@@ -244,7 +249,7 @@ class ContextProjectGradFunctor {
if (pad_grad) { if (pad_grad) {
if (padding_trainable) { if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]), Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1])); static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
...@@ -259,7 +264,7 @@ class ContextProjectGradFunctor { ...@@ -259,7 +264,7 @@ class ContextProjectGradFunctor {
k + context_length < up_pad ? context_length : up_pad - k; k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_t.Slice(k * context_length, Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size); k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size); Tensor w_sub = padding_data->Slice(k, k + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub); auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub); auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e.device(*context.GetEigenDevice<Place>()) =
...@@ -292,7 +297,7 @@ class ContextProjectGradFunctor { ...@@ -292,7 +297,7 @@ class ContextProjectGradFunctor {
Tensor out_t_sub = out_t.Slice( Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size, (down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length); (down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice( Tensor w_sub = padding_data->Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size); up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub); auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub); auto w_sub_e = EigenMatrix<T>::From(w_sub);
......
...@@ -28,57 +28,55 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -28,57 +28,55 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, const std::vector<int>& dilation,
int stride_height, int stride_width, int padding_up, const std::vector<int>& stride,
int padding_down, int padding_left, int padding_right) { const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col->dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col->dims()[1];
int filter_width = col.dims()[2]; int filter_width = col->dims()[2];
int output_height = col.dims()[3]; int col_height = col->dims()[3];
int output_width = col.dims()[4]; int col_width = col->dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(input_height + padding_up + padding_down - filter_height) / ((dilation[0] * (filter_height - 1) + 1))) /
stride_height + stride[0] +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(input_width + padding_left + padding_right - filter_width) / ((dilation[1] * (filter_width - 1) + 1))) /
stride_width + stride[1] +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride_width + w_offset - padding_left; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx >= input_width) { im_col_idx < 0 || im_col_idx >= im_width)
col_data[(c * output_height + h) * output_width + w] = T(0); ? static_cast<T>(0)
} else { : im_data[im_idx];
im_row_idx += c_im * input_height;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
} }
} }
} }
...@@ -94,54 +92,55 @@ template <class T> ...@@ -94,54 +92,55 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context,
const framework::Tensor& col, int stride_height, const framework::Tensor& col,
int stride_width, int padding_up, int padding_down, const std::vector<int>& dilation,
int padding_left, int padding_right) { const std::vector<int>& stride,
PADDLE_ENFORCE(im.dims().size() == 3); const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im->dims()[0];
int input_height = im.dims()[1]; int im_height = im->dims()[1];
int input_width = im.dims()[2]; int im_width = im->dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(input_height + padding_up + padding_down - filter_height) / ((dilation[0] * (filter_height - 1) + 1))) /
stride_height + stride[0] +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(input_width + padding_left + padding_right - filter_width) / ((dilation[1] * (filter_width - 1) + 1))) /
stride_width + stride[1] +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
T* im_data = im.data<T>(); T* im_data = im->data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride_width + w_offset - padding_left; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < input_height && if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) { (im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * input_height; im_row_idx += c_im * im_height;
im_data[im_row_idx * input_width + im_col_idx] += im_data[im_row_idx * im_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w]; col_data[(c * col_height + h) * col_width + w];
} }
} }
} }
...@@ -168,64 +167,59 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -168,64 +167,59 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, const std::vector<int>& dilation,
int stride_height, int stride_width, int padding_up, const std::vector<int>& stride,
int padding_down, int padding_left, int padding_right) { const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col->dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col->dims()[3];
int filter_width = col.dims()[4]; int filter_width = col->dims()[4];
int output_height = col.dims()[0]; int col_height = col->dims()[0];
int output_width = col.dims()[1]; int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) / (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
stride_height + col_height,
1,
output_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) / (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col->data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset = ((((col_row_idx)*output_width + col_col_idx) * int col_offset =
input_channels + ((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) { int im_offset = (channel * im_height + im_row_offset) * im_width +
col_data[col_offset] = T(0);
} else {
int im_offset =
(channel * input_height + im_row_offset) * input_width +
im_col_offset; im_col_offset;
col_data[col_offset] = im_data[im_offset]; col_data[col_offset] =
} (im_row_offset < 0 || im_row_offset >= im_height ||
im_col_offset < 0 || im_col_offset >= im_width)
? static_cast<T>(0)
: im_data[im_offset];
} }
} }
} }
...@@ -243,60 +237,57 @@ template <class T> ...@@ -243,60 +237,57 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context,
const framework::Tensor& col, int stride_height, const framework::Tensor& col,
int stride_width, int padding_up, int padding_down, const std::vector<int>& dilation,
int padding_left, int padding_right) { const std::vector<int>& stride,
PADDLE_ENFORCE(im.dims().size() == 3); const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im->dims()[0];
int input_height = im.dims()[1]; int im_height = im->dims()[1];
int input_width = im.dims()[2]; int im_width = im->dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) / (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
stride_height + col_height,
1,
output_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) / (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
T* im_data = im.data<T>(); T* im_data = im->data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
input_channels + (((col_row_idx * col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height && if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < input_width) { im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset = int im_offset =
(channel * input_height + im_row_offset) * input_width + (channel * im_height + im_row_offset) * im_width +
im_col_offset; im_col_offset;
im_data[im_offset] += col_data[col_offset]; im_data[im_offset] += col_data[col_offset];
} }
......
此差异已折叠。
...@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; ...@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
* \param colData Column data. * \param colData Column data.
* \param colShape The shape of colData. * \param colShape The shape of colData.
* *
* \param dilations dilation data.
* \param 2-dimension [dilation_height, dilation_width].
*
* \param strides stride data.
* \param 2-dimension [stride_height, stride_width].
*
* \param paddings padding data.
* \param 4-dimension [up_pad, left_pad, down_pad, right_pad].
*
* If the template argument Format is kCFO, the shape of colData is: * If the template argument Format is kCFO, the shape of colData is:
* [input_channels, filter_height, filter_width, output_height, output_width] * [input_channels, filter_height, filter_width, output_height, output_width]
* So, it is easy to reshape into a convolution matrix for convolution * So, it is easy to reshape into a convolution matrix for convolution
...@@ -73,18 +82,19 @@ template <ColFormat Format, typename Place, typename T> ...@@ -73,18 +82,19 @@ template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor { class Im2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, const std::vector<int>& dilation,
int stride_height, int stride_width, int padding_up, const std::vector<int>& stride,
int padding_down, int padding_left, int padding_right); const std::vector<int>& padding, framework::Tensor* col);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context,
const framework::Tensor& col, int stride_height, const framework::Tensor& col,
int stride_width, int padding_up, int padding_down, const std::vector<int>& dilation,
int padding_left, int padding_right); const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im);
}; };
} // namespace math } // namespace math
......
...@@ -45,10 +45,14 @@ void testIm2col() { ...@@ -45,10 +45,14 @@ void testIm2col() {
int input_height = 2; int input_height = 2;
int input_width = 3; int input_width = 3;
int filter_size = 2; int filter_size = 2;
int stride = 1; std::vector<int> stride({1, 1}); // stride_y, stride_x
int padding = 0; std::vector<int> padding(
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; {0, 0, 0, 0}); // up_pad, left_pad, down_pad, right_pad
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; std::vector<int> dilation({1, 1}); // dilation_y, dilation_x
int output_height =
(input_height - filter_size + padding[0] + padding[1]) / stride[0] + 1;
int output_width =
(input_width - filter_size + padding[2] + padding[3]) / stride[1] + 1;
float* input_ptr = input_tmp.mutable_data<float>( float* input_ptr = input_tmp.mutable_data<float>(
{1, input_height, input_width}, paddle::platform::CPUPlace()); {1, input_height, input_width}, paddle::platform::CPUPlace());
float arr[6] = {0, 1, 2, 3, 4, 5}; float arr[6] = {0, 1, 2, 3, 4, 5};
...@@ -85,10 +89,8 @@ void testIm2col() { ...@@ -85,10 +89,8 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float> paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, im2col(*context, input, dilation, stride, padding, &output_cfo);
padding); im2col_ocf(*context, input, dilation, stride, padding, &output_ocf);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
...@@ -131,8 +133,7 @@ void testIm2col() { ...@@ -131,8 +133,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, col2im(*context, output_cfo, dilation, stride, padding, &input);
padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -153,8 +154,7 @@ void testIm2col() { ...@@ -153,8 +154,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);
padding, padding);
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
......
...@@ -28,28 +28,51 @@ template <class T> ...@@ -28,28 +28,51 @@ template <class T>
class Vol2ColFunctor<platform::CPUPlace, T> { class Vol2ColFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const { const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol.dims()[1];
int input_height = vol.dims()[2]; int input_height = vol.dims()[2];
int input_width = vol.dims()[3]; int input_width = vol.dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col.dims()[2]; int filter_height = col->dims()[2];
int filter_width = col.dims()[3]; int filter_width = col->dims()[3];
int output_depth = col.dims()[4]; int output_depth = col->dims()[4];
int output_height = col.dims()[5]; int output_height = col->dims()[5];
int output_width = col.dims()[6]; int output_width = col->dims()[6];
int channels_col = int channels_col =
input_channels * filter_depth * filter_height * filter_width; input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"mismatching.");
const T* vol_data = vol.data<T>(); const T* vol_data = vol.data<T>();
T* col_data = col.data<T>(); T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
...@@ -57,24 +80,23 @@ class Vol2ColFunctor<platform::CPUPlace, T> { ...@@ -57,24 +80,23 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth; int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth; int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) { for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset; int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
int h_pad = h * stride_height - padding_height + h_offset; int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int w_pad = w * stride_width - padding_width + w_offset; int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
int col_idx = int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w; ((c * output_depth + d) * output_height + h) * output_width + w;
if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
col_data[col_idx] = static_cast<T>(0);
} else {
int vol_idx = int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) * ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width + input_width +
w_pad; w_pad;
col_data[col_idx] = vol_data[vol_idx]; col_data[col_idx] =
} (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth)
? static_cast<T>(0)
: vol_data[vol_idx];
} }
} }
} }
...@@ -92,17 +114,18 @@ template <class T> ...@@ -92,17 +114,18 @@ template <class T>
class Col2VolFunctor<platform::CPUPlace, T> { class Col2VolFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const { const std::vector<int>& paddings,
PADDLE_ENFORCE(vol.dims().size() == 4); framework::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol->dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol->dims()[1];
int input_height = vol.dims()[2]; int input_height = vol->dims()[2];
int input_width = vol.dims()[3]; int input_width = vol->dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
...@@ -112,7 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> { ...@@ -112,7 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int channels_col = int channels_col =
input_channels * filter_depth * filter_height * filter_width; input_channels * filter_depth * filter_height * filter_width;
T* vol_data = vol.data<T>(); PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"mismatching.");
T* vol_data = vol->data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
...@@ -121,11 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> { ...@@ -121,11 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth; int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth; int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) { for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset; int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
int h_pad = h * stride_height - padding_height + h_offset; int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int w_pad = w * stride_width - padding_width + w_offset; int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
...@@ -133,6 +177,7 @@ class Col2VolFunctor<platform::CPUPlace, T> { ...@@ -133,6 +177,7 @@ class Col2VolFunctor<platform::CPUPlace, T> {
((cIm * input_depth + d_pad) * input_height + h_pad) * ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width + input_width +
w_pad; w_pad;
int col_idx = int col_idx =
((c * output_depth + d) * output_height + h) * output_width + ((c * output_depth + d) * output_height + h) * output_width +
w; w;
......
...@@ -21,11 +21,12 @@ namespace math { ...@@ -21,11 +21,12 @@ namespace math {
template <class T> template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth, __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int height, int width, int filter_depth, int height, int width, int dilation_d, int dilation_h,
int filter_height, int filter_width, int stride_depth, int dilation_w, int filter_depth, int filter_height,
int stride_height, int stride_width, int padding_depth, int filter_width, int stride_depth, int stride_height,
int padding_height, int padding_width, int output_detph, int stride_width, int padding_depth, int padding_height,
int output_height, int output_width, T* data_col) { int padding_width, int output_detph, int output_height,
int output_width, T* data_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
int w_out = index % output_width; int w_out = index % output_width;
...@@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, ...@@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
for (int k = 0; k < filter_depth; ++k) { for (int k = 0; k < filter_depth; ++k) {
for (int i = 0; i < filter_height; ++i) { for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) { for (int j = 0; j < filter_width; ++j) {
int d = d_in + k; int d = d_in + k * dilation_d;
int h = h_in + i; int h = h_in + i * dilation_h;
int w = w_in + j; int w = w_in + j * dilation_w;
int col_idx = (k * dilation_d * height + i * dilation_h) * width +
j * dilation_w;
*data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width) w < width)
? data_vol[(k * height + i) * width + j] ? data_vol[col_idx]
: 0; : 0;
data_col += output_detph * output_height * output_width; data_col += output_detph * output_height * output_width;
} }
...@@ -68,23 +71,46 @@ template <class T> ...@@ -68,23 +71,46 @@ template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> { class Vol2ColFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const { const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol.dims()[1];
int input_height = vol.dims()[2]; int input_height = vol.dims()[2];
int input_width = vol.dims()[3]; int input_width = vol.dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col.dims()[2]; int filter_height = col->dims()[2];
int filter_width = col.dims()[3]; int filter_width = col->dims()[3];
int output_depth = col.dims()[4]; int output_depth = col->dims()[4];
int output_height = col.dims()[5]; int output_height = col->dims()[5];
int output_width = col.dims()[6]; int output_width = col->dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
int num_outputs = int num_outputs =
input_channels * output_depth * output_height * output_width; input_channels * output_depth * output_height * output_width;
...@@ -95,19 +121,25 @@ class Vol2ColFunctor<platform::GPUPlace, T> { ...@@ -95,19 +121,25 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width, num_outputs, vol.data<T>(), input_depth, input_height, input_width,
filter_depth, filter_height, filter_width, stride_depth, stride_height, dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
stride_width, padding_depth, padding_height, padding_width, filter_width, strides[0], strides[1], strides[2], paddings[0],
output_depth, output_height, output_width, col.data<T>()); paddings[1], paddings[2], output_depth, output_height, output_width,
col->data<T>());
} }
}; };
template <class T> template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth, __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int height, int width, int filter_depth, int height, int width, int dilation_d, int dilation_h,
int filter_height, int filter_width, int stride_depth, int dilation_w, int filter_depth, int filter_height,
int stride_height, int stride_width, int padding_depth, int filter_width, int stride_depth, int stride_height,
int padding_height, int padding_width, int output_detph, int stride_width, int padding_depth, int padding_height,
int output_height, int output_width, T* data_vol) { int padding_width, int output_detph, int output_height,
int output_width, T* data_vol) {
const int d_filter_depth = dilation_d * (filter_depth - 1) + 1;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
T src_val = 0; T src_val = 0;
...@@ -115,35 +147,41 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, ...@@ -115,35 +147,41 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int h = (index / width) % height + padding_height; int h = (index / width) % height + padding_height;
int d = (index / width / height) % depth + padding_depth; int d = (index / width / height) % depth + padding_depth;
int c = index / width / height / depth; int c = index / width / height / depth;
// compute the start and end of the output // compute the start and end of the output
int w_col_start = int w_col_start =
(w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
int w_col_end = min(w / stride_width + 1, output_width); int w_col_end = min(w / stride_width + 1, output_width);
int h_col_start = int h_col_start =
(h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
int h_col_end = min(h / stride_height + 1, output_height); int h_col_end = min(h / stride_height + 1, output_height);
int d_col_start = int d_col_start =
(d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; (d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1;
int d_col_end = min(d / stride_depth + 1, output_detph); int d_col_end = min(d / stride_depth + 1, output_detph);
int offset = (c * filter_depth * filter_height * filter_width +
d * filter_width * filter_height + h * filter_width + w) *
output_detph * output_height * output_width;
int coeff_d_col =
(1 - stride_depth * filter_width * filter_height * output_detph) *
output_height * output_width;
int coeff_h_col =
(1 - stride_height * filter_width * output_detph * output_height) *
output_width;
int coeff_w_col =
(1 - stride_width * output_detph * output_height * output_width);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
src_val += data_col[offset + d_col * coeff_d_col + int d_off = (d - d_col * stride_depth);
h_col * coeff_h_col + w_col * coeff_w_col]; int h_off = (h - h_col * stride_height);
int w_off = (w - w_col * stride_width);
if (d_off % dilation_d == 0 && h_off % dilation_h == 0 &&
w_off % dilation_w == 0) {
d_off /= dilation_d;
h_off /= dilation_h;
w_off /= dilation_w;
int data_col_index =
(((((c * filter_depth + d_off) * filter_height + h_off) *
filter_width +
w_off)));
data_col_index =
((data_col_index * output_detph + d_col) * output_height +
h_col) *
output_width +
w_col;
src_val += data_col[data_col_index];
}
} }
} }
} }
...@@ -161,17 +199,18 @@ template <class T> ...@@ -161,17 +199,18 @@ template <class T>
class Col2VolFunctor<platform::GPUPlace, T> { class Col2VolFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const { const std::vector<int>& paddings,
PADDLE_ENFORCE(vol.dims().size() == 4); framework::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol->dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol->dims()[1];
int input_height = vol.dims()[2]; int input_height = vol->dims()[2];
int input_width = vol.dims()[3]; int input_width = vol->dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
...@@ -179,6 +218,28 @@ class Col2VolFunctor<platform::GPUPlace, T> { ...@@ -179,6 +218,28 @@ class Col2VolFunctor<platform::GPUPlace, T> {
int output_height = col.dims()[5]; int output_height = col.dims()[5];
int output_width = col.dims()[6]; int output_width = col.dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
int num_kernels = input_channels * input_depth * input_height * input_width; int num_kernels = input_channels * input_depth * input_height * input_width;
const int threads = 1024; const int threads = 1024;
...@@ -188,9 +249,10 @@ class Col2VolFunctor<platform::GPUPlace, T> { ...@@ -188,9 +249,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width, num_kernels, col.data<T>(), input_depth, input_height, input_width,
filter_depth, filter_height, filter_width, stride_depth, stride_height, dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
stride_width, padding_depth, padding_height, padding_width, filter_width, strides[0], strides[1], strides[2], paddings[0],
output_depth, output_height, output_width, vol.data<T>()); paddings[1], paddings[2], output_depth, output_height, output_width,
vol->data<T>());
} }
}; };
......
...@@ -31,6 +31,15 @@ namespace math { ...@@ -31,6 +31,15 @@ namespace math {
* \param colData Column data. * \param colData Column data.
* \param colShape The shape of colData. * \param colShape The shape of colData.
* *
* \param dilations dilation data.
* \param 3-dimension [dilation_depth, dilation_height, dilation_width].
*
* \param strides stride data.
* \param 3-dimension [stride_depth, stride_height, stride_width].
*
* \param paddings padding data.
* \param 3-dimension [d_pad, h_pad, w_pad].
*
* The shape of colData is: * The shape of colData is:
* [input_channels, filter_depth, filter_height, filter_width, output_depth, * [input_channels, filter_depth, filter_height, filter_width, output_depth,
* output_height, output_width] * output_height, output_width]
...@@ -57,20 +66,22 @@ template <typename Place, typename T> ...@@ -57,20 +66,22 @@ template <typename Place, typename T>
class Vol2ColFunctor { class Vol2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const; const std::vector<int>& paddings,
framework::Tensor* col) const;
}; };
template <typename Place, typename T> template <typename Place, typename T>
class Col2VolFunctor { class Col2VolFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& dilations,
int padding_depth, int padding_height, const std::vector<int>& strides,
int padding_width) const; const std::vector<int>& paddings,
framework::Tensor* vol) const;
}; };
} // namespace math } // namespace math
......
...@@ -62,11 +62,15 @@ void testVol2col() { ...@@ -62,11 +62,15 @@ void testVol2col() {
int input_height = 2; int input_height = 2;
int input_width = 3; int input_width = 3;
int filter_size = 2; int filter_size = 2;
int stride = 1; std::vector<int> strides({1, 1, 1});
int padding = 0; std::vector<int> paddings({0, 0, 0});
int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; std::vector<int> dilations({1, 1, 1});
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_depth =
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; (input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1;
int output_height =
(input_height - filter_size + 2 * paddings[1]) / strides[1] + 1;
int output_width =
(input_width - filter_size + 2 * paddings[2]) / strides[2] + 1;
// Vol2Col test // Vol2Col test
float* input_ptr = float* input_ptr =
...@@ -85,8 +89,7 @@ void testVol2col() { ...@@ -85,8 +89,7 @@ void testVol2col() {
*place); *place);
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col; paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, input, output, stride, stride, stride, padding, padding, vol2col(*context, input, dilations, strides, paddings, &output);
padding);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr; float* out_cfo_ptr;
...@@ -111,8 +114,7 @@ void testVol2col() { ...@@ -111,8 +114,7 @@ void testVol2col() {
} }
paddle::operators::math::Col2VolFunctor<Place, float> col2vol; paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, input, output, stride, stride, stride, padding, padding, col2vol(*context, output, dilations, strides, paddings, &input);
padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
......
...@@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> { ...@@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> {
math::ContextProjectFunctor<Place, T> seq_project_functor; math::ContextProjectFunctor<Place, T> seq_project_functor;
seq_project_functor(context.device_context(), *in, *padding_data, col, seq_project_functor(context.device_context(), *in, *padding_data,
padding_trainable, context_start, context_length, padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad); context_stride, up_pad, down_pad, &col);
math::matmul<Place, T>(context.device_context(), col, false, filter, false, math::matmul<Place, T>(context.device_context(), col, false, filter, false,
static_cast<T>(1.0), out, static_cast<T>(0.0)); static_cast<T>(1.0), out, static_cast<T>(0.0));
...@@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
in_g->set_lod(in->lod()); in_g->set_lod(in->lod());
set_zero(context.device_context(), in_g, static_cast<T>(0)); set_zero(context.device_context(), in_g, static_cast<T>(0));
seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g, seq_project_grad_functor(context.device_context(), *in_g,
col, padding_trainable, context_start, padding_trainable, context_start, context_length,
context_length, context_stride, up_pad, down_pad, context_stride, up_pad, down_pad, false, true,
true, false); padding_data_g, &col);
} }
if (padding_trainable && padding_data_g) { if (padding_trainable && padding_data_g) {
...@@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
LoDTensor* input = const_cast<LoDTensor*>(in); LoDTensor* input = const_cast<LoDTensor*>(in);
seq_project_grad_functor(context.device_context(), *input, seq_project_grad_functor(context.device_context(), *input,
*padding_data_g, col, padding_trainable, padding_trainable, context_start, context_length,
context_start, context_length, context_stride, context_stride, up_pad, down_pad, true, false,
up_pad, down_pad, false, true); padding_data_g, &col);
} }
if (filter_g) { if (filter_g) {
...@@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
padding_data = context.Input<Tensor>("PaddingData"); padding_data = context.Input<Tensor>("PaddingData");
} }
seq_project_functor(context.device_context(), *in, *padding_data, col, seq_project_functor(context.device_context(), *in, *padding_data,
padding_trainable, context_start, context_length, padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad); context_stride, up_pad, down_pad, &col);
math::matmul<Place, T>(context.device_context(), col, true, out_grad, math::matmul<Place, T>(context.device_context(), col, true, out_grad,
false, T(1.0), &filter_grad, T(1.0)); false, T(1.0), &filter_grad, T(1.0));
......
...@@ -10,23 +10,33 @@ def conv2d_forward_naive(input, filter, group, conv_param): ...@@ -10,23 +10,33 @@ def conv2d_forward_naive(input, filter, group, conv_param):
assert np.mod(out_c, group) == 0 assert np.mod(out_c, group) == 0
sub_out_c = out_c / group sub_out_c = out_c / group
stride, pad = conv_param['stride'], conv_param['pad'] stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0] 'dilation']
out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1] out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) / stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1]
out = np.zeros((in_n, out_c, out_h, out_w)) out = np.zeros((in_n, out_c, out_h, out_w))
d_bolck_w = (dilation[0] * (f_h - 1) + 1)
d_bolck_h = (dilation[1] * (f_w - 1) + 1)
input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )),
mode='constant', mode='constant',
constant_values=0) constant_values=0)
filter_dilation = np.zeros((out_c, f_c, d_bolck_h, d_bolck_w))
filter_dilation[:, :, 0:d_bolck_h:dilation[0], 0:d_bolck_w:dilation[
1]] = filter
for i in range(out_h): for i in range(out_h):
for j in range(out_w): for j in range(out_w):
for g in range(group): for g in range(group):
input_pad_masked = \ input_pad_masked = \
input_pad[:, g * f_c:(g + 1) * f_c, input_pad[:, g * f_c:(g + 1) * f_c,
i * stride[0]:i * stride[0] + f_h, i * stride[0]:i * stride[0] + d_bolck_h,
j * stride[1]:j * stride[1] + f_w] j * stride[1]:j * stride[1] + d_bolck_w]
f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :] f_sub = filter_dilation[g * sub_out_c:(g + 1) *
sub_out_c, :, :, :]
for k in range(sub_out_c): for k in range(sub_out_c):
out[:, g * sub_out_c + k, i, j] = \ out[:, g * sub_out_c + k, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :], np.sum(input_pad_masked * f_sub[k, :, :, :],
...@@ -39,9 +49,14 @@ class TestConv2dOp(OpTest): ...@@ -39,9 +49,14 @@ class TestConv2dOp(OpTest):
def setUp(self): def setUp(self):
self.init_op_type() self.init_op_type()
self.init_group() self.init_group()
self.init_dilation()
self.init_test_case() self.init_test_case()
conv2d_param = {'stride': self.stride, 'pad': self.pad} conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype("float32") input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32")
output = conv2d_forward_naive(input, filter, self.groups, output = conv2d_forward_naive(input, filter, self.groups,
...@@ -80,12 +95,14 @@ class TestConv2dOp(OpTest): ...@@ -80,12 +95,14 @@ class TestConv2dOp(OpTest):
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
...@@ -101,24 +118,66 @@ class TestWithGroup(TestConv2dOp): ...@@ -101,24 +118,66 @@ class TestWithGroup(TestConv2dOp):
self.op_type = "conv2d" self.op_type = "conv2d"
#----------------Conv2dCudnn---------------- class TestWith1x1(TestConv2dOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1]
def init_dilation(self):
self.dilations = [1, 1]
class TestCudnn(TestConv2dOp):
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 3
def init_op_type(self): def init_op_type(self):
self.op_type = "conv_cudnn" self.op_type = "conv2d"
class TestCudnnWithGroup(TestConv2dOp): class TestWithDilation(TestConv2dOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [2, 2]
def init_group(self): def init_group(self):
self.groups = 3 self.groups = 3
def init_op_type(self):
self.op_type = "conv2d"
#----------------Conv2dCudnn----------------
class TestCudnn(TestConv2dOp):
def init_op_type(self):
self.op_type = "conv_cudnn"
class TestCudnnWithGroup(TestWithGroup):
def init_op_type(self): def init_op_type(self):
self.op_type = "conv_cudnn" self.op_type = "conv_cudnn"
class TestCudnnWith1x1(TestWith1x1):
def init_op_type(self):
self.op_type = "conv_cudnn"
# cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -10,26 +10,39 @@ def conv3d_forward_naive(input, filter, group, conv_param): ...@@ -10,26 +10,39 @@ def conv3d_forward_naive(input, filter, group, conv_param):
assert np.mod(out_c, group) == 0 assert np.mod(out_c, group) == 0
sub_out_c = out_c / group sub_out_c = out_c / group
stride, pad = conv_param['stride'], conv_param['pad'] stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[
out_d = 1 + (in_d + 2 * pad[0] - f_h) / stride[0] 'dilations']
out_h = 1 + (in_h + 2 * pad[1] - f_h) / stride[1]
out_w = 1 + (in_w + 2 * pad[2] - f_w) / stride[2] out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) / stride[0]
out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) / stride[1]
out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) / stride[2]
out = np.zeros((in_n, out_c, out_d, out_h, out_w)) out = np.zeros((in_n, out_c, out_d, out_h, out_w))
d_bolck_d = (dilation[0] * (f_d - 1) + 1)
d_bolck_h = (dilation[1] * (f_h - 1) + 1)
d_bolck_w = (dilation[2] * (f_w - 1) + 1)
input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ), input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ),
(pad[2], )), (pad[2], )),
mode='constant', mode='constant',
constant_values=0) constant_values=0)
filter_dilation = np.zeros((out_c, f_c, d_bolck_d, d_bolck_h, d_bolck_w))
filter_dilation[:, :, 0:d_bolck_d:dilation[0], 0:d_bolck_h:dilation[1], 0:
d_bolck_w:dilation[2]] = filter
for d in range(out_d): for d in range(out_d):
for i in range(out_h): for i in range(out_h):
for j in range(out_w): for j in range(out_w):
for g in range(group): for g in range(group):
input_pad_masked = \ input_pad_masked = \
input_pad[:, g * f_c:(g + 1) * f_c, input_pad[:, g * f_c:(g + 1) * f_c,
d * stride[0]:d * stride[0] + f_d, d * stride[0]:d * stride[0] + d_bolck_d,
i * stride[1]:i * stride[1] + f_h, i * stride[1]:i * stride[1] + d_bolck_h,
j * stride[2]:j * stride[2] + f_w] j * stride[2]:j * stride[2] + d_bolck_w]
f_sub = filter[g * sub_out_c:(g + 1) *
f_sub = filter_dilation[g * sub_out_c:(g + 1) *
sub_out_c, :, :, :, :] sub_out_c, :, :, :, :]
for k in range(sub_out_c): for k in range(sub_out_c):
out[:, g * sub_out_c + k, d, i, j] = \ out[:, g * sub_out_c + k, d, i, j] = \
...@@ -43,9 +56,14 @@ class TestConv3dOp(OpTest): ...@@ -43,9 +56,14 @@ class TestConv3dOp(OpTest):
def setUp(self): def setUp(self):
self.init_group() self.init_group()
self.init_op_type() self.init_op_type()
self.init_dilation()
self.init_test_case() self.init_test_case()
conv3d_param = {'stride': self.stride, 'pad': self.pad} conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype("float32") input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32")
output = conv3d_forward_naive(input, filter, self.groups, output = conv3d_forward_naive(input, filter, self.groups,
...@@ -55,7 +73,8 @@ class TestConv3dOp(OpTest): ...@@ -55,7 +73,8 @@ class TestConv3dOp(OpTest):
self.attrs = { self.attrs = {
'strides': self.stride, 'strides': self.stride,
'paddings': self.pad, 'paddings': self.pad,
'groups': self.groups 'groups': self.groups,
'dilations': self.dilations
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
...@@ -88,6 +107,9 @@ class TestConv3dOp(OpTest): ...@@ -88,6 +107,9 @@ class TestConv3dOp(OpTest):
f_c = self.input_size[1] / self.groups f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3, 3] self.filter_size = [6, f_c, 3, 3, 3]
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
...@@ -104,27 +126,47 @@ class TestCase1(TestConv3dOp): ...@@ -104,27 +126,47 @@ class TestCase1(TestConv3dOp):
f_c = self.input_size[1] / self.groups f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3, 3] self.filter_size = [6, f_c, 3, 3, 3]
def init_group(self):
self.groups = 1
def init_op_type(self): class TestWithGroup1(TestConv3dOp):
self.op_type = "conv3d" def init_group(self):
self.groups = 3
class TestWithGroup1(TestConv3dOp): class TestWithGroup2(TestCase1):
def init_group(self): def init_group(self):
self.groups = 3 self.groups = 3
def init_op_type(self):
self.op_type = "conv3d"
class TestWith1x1(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 4, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
self.dilations = [1, 1, 1]
class TestWithGroup2(TestCase1):
def init_group(self): def init_group(self):
self.groups = 3 self.groups = 3
def init_op_type(self):
self.op_type = "conv3d" class TestWithDilation(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 6, 6, 6] # NCDHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 2, 2, 2]
def init_dilation(self):
self.dilations = [2, 2, 2]
def init_group(self):
self.groups = 3
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册