提交 356d6954 编写于 作者: C chengduoZH

follow comments

上级 7d73b8fc
......@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/operators/conv_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -53,7 +54,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"The number of output channels should be divided by groups.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) {
for (size_t i = 0; i < strides.size(); ++i) {
PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
(dilations[i] * (filter_dims[i + 2] - 1) + 1) >
0,
......@@ -61,8 +62,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], paddings[i],
strides[i]));
dilations[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......@@ -86,9 +86,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0, 0}), the "
"paddings(h_pad, w_pad) of "
"convolution operator.")
.SetDefault({0, 0});
AddAttr<int>(
"groups",
......@@ -99,9 +105,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1}), the dilations of "
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault(std::vector<int>{1, 1});
.SetDefault({1, 1});
AddComment(R"DOC(
Convolution Operator.
......@@ -147,13 +154,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>(
"strides",
"(vector, default:{0, 0, 0}), the strides of convolution operator.")
AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector, default:{0, 0, 0}), the paddings of convolution operator.")
AddAttr<std::vector<int>>("paddings",
"(vector<int>, default:{0, 0, 0}), the "
"paddings(d_pad, h_pad, w_pad) of convolution "
"operator.")
.SetDefault({0, 0, 0});
AddAttr<int>(
"groups",
......@@ -164,10 +173,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1, 1}), the dilations of "
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation, h_dilation, w_dilation) of "
"convolution operator. Currently, conv3d doesn't "
"support dilation.")
.SetDefault(std::vector<int>{1, 1, 1});
.SetDefault({1, 1, 1});
AddComment(R"DOC(
Convolution3D Operator.
......
......@@ -28,24 +28,22 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int dilation,
int padding_up, int padding_down, int stride) {
int output_size = (input_size + padding_up + padding_down -
(dilation * (filter_size - 1) + 1)) /
stride +
1;
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
return output_size;
}
inline bool NotExpand(std::vector<int64_t>& filter_dim,
std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& dilations) {
inline bool IsExpand(std::vector<int64_t>& filter_dim,
std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 &= (static_cast<int>(filter_dim[j]) == 1);
strides_1 &= (strides[j] == 1);
padding_0 &= (paddings[j] == 0);
dilation_1 &= (dilations[j] == 1);
filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return filter_1 && strides_1 && padding_0 && dilation_1;
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
// Define Op classes in .h file so that other conv
......@@ -65,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
};
class ConvOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
};
......@@ -88,9 +84,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -122,13 +118,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix;
if (!not_expand) {
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace());
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
......@@ -149,51 +145,37 @@ class GemmConvKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
if (!not_expand) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
math::Vol2ColFunctor<Place, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
if (filter_shape_vec.size() == 2) {
// im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0],
paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
// vol2col
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], dilations[2], strides[0], strides[1],
strides[2], paddings[0], paddings[1], paddings[2]);
}
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
}
}
} else {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
} else if (filter_shape_vec.size() == 2) {
// im2col
im2col(context.device_context(), in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
// vol2col
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
}
}
}
......@@ -217,9 +199,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (!input_grad && !filter_grad) return;
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -270,13 +252,13 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix;
if (!not_expand) {
if (is_expand) {
col.mutable_data<T>(col_shape, context.GetPlace());
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
......@@ -288,61 +270,38 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
if (!not_expand) {
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch =
input_grad->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice =
filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
if (filter_shape_vec.size() == 2) {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), in_grad_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0],
paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Col2VolFunctor<Place, T> col2vol;
col2vol(context.device_context(), in_grad_slice, col,
dilations[0], dilations[1], dilations[2], strides[0],
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
}
}
}
} else {
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch =
input_grad->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice =
filter.Slice(g * out_step, (g + 1) * out_step);
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
math::Col2VolFunctor<Place, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
}
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
if (is_expand && filter_shape_vec.size() == 2) {
col2im(context.device_context(), col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&in_grad_slice);
} else if (is_expand && filter_shape_vec.size() == 3) {
col2vol(context.device_context(), col, dilations, strides, paddings,
&in_grad_slice);
}
}
}
......@@ -353,60 +312,38 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
set_zero(context.device_context(), filter_grad, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!not_expand) {
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (filter_shape_vec.size() == 2) {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0],
paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], dilations[2], strides[0], strides[1],
strides[2], paddings[0], paddings[1], paddings[2]);
}
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
} else {
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
} else if (filter_shape_vec.size() == 2) {
im2col(context.device_context(), in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
}
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
......
......@@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"as the number of filters.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
for (size_t i = 0; i < paddings.size(); ++i) {
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
filter_dims[i + 2]);
}
......@@ -77,13 +77,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>(
"strides",
"(vector defalut:{1, 1}), strides of convolution transpose operator.")
AddAttr<std::vector<int>>("strides",
"(vector<int> defalut:{1, 1}), strides of "
"convolution transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector defalut:{0, 0}), paddings of convolution transpose operator.")
"(vector<int> defalut:{0, 0}), paddings(h_pad, w_pad) of convolution "
"transpose operator.")
.SetDefault({0, 0});
AddComment(R"DOC(
Convolution2D Transpose Operator.
......@@ -132,13 +133,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature.");
AddAttr<std::vector<int>>(
"strides",
"(vector defalut:{1, 1, 1}), strides of convolution transpose operator.")
AddAttr<std::vector<int>>("strides",
"(vector<int> defalut:{1, 1, 1}), strides of "
"convolution transpose operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.")
AddAttr<std::vector<int>>("paddings",
"(vector<int> defalut:{0, 0, 0}), paddings(d_pad, "
"h_pad, w_pad) of convolution transpose operator.")
.SetDefault({0, 0, 0});
AddComment(R"DOC(
Convolution3D Transpose Operator.
......
......@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvTransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override;
};
......@@ -66,13 +62,11 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
int dilaiton_d = 1;
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
......@@ -124,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), output, static_cast<T>(0));
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<Place, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input)
for (int i = 0; i < batch_size; i++) {
......@@ -142,17 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) {
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), output_batch, col, dilation_h,
dilation_w, strides[0], strides[1], 0, 0, 0, 0);
col2im(context.device_context(), col,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&output_batch);
} else if (filter_shape_vec.size() == 3) {
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
math::Col2VolFunctor<Place, T> col2vol;
col2vol(context.device_context(), output_batch, col, dilaiton_d,
dilation_h, dilation_w, strides[0], strides[1], strides[2], 0,
0, 0);
col2vol(context.device_context(), col, dilations, strides,
std::vector<int>{0, 0, 0}, &output_batch);
}
}
}
......@@ -179,10 +176,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int dilaiton_d = 1;
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
......@@ -237,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
Tensor filter_grad_;
math::SetConstant<Place, T> set_zero;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
......@@ -256,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), output_grad_batch, col, dilation_h,
dilation_w, strides[0], strides[1], paddings[0], paddings[0],
paddings[1], paddings[1]);
im2col(context.device_context(), output_grad_batch,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (filter_shape_vec.size() == 3) {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), output_grad_batch, col, dilaiton_d,
dilation_h, dilation_w, strides[0], strides[1], strides[2],
paddings[0], paddings[1], paddings[2]);
vol2col(context.device_context(), output_grad_batch, dilations,
strides, paddings, &col);
}
if (input_grad) {
......
......@@ -95,8 +95,9 @@ class ContextProjectFunctor {
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
int dilation_h = 1;
int dilation_w = 1;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
std::vector<int> stride({context_stride, 1});
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
......@@ -126,10 +127,7 @@ class ContextProjectFunctor {
{1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
down_pad, 0, 0);
im2col_ocf(context, in_t, dilation, stride, padding, &out_t);
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
......@@ -207,8 +205,9 @@ class ContextProjectGradFunctor {
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
int dilation_h = 1;
int dilation_w = 1;
std::vector<int> dilation({1, 1});
std::vector<int> padding({up_pad, 0, down_pad, 0});
std::vector<int> stride({context_stride, 1});
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
......@@ -240,9 +239,7 @@ class ContextProjectGradFunctor {
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
col2im_ocf(context, out_t, dilation, stride, padding, &in_t);
out_t.Resize({sequence_height, context_length * sequence_width});
}
}
......
......@@ -28,40 +28,39 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
......@@ -69,10 +68,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx =
h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
......@@ -95,38 +92,39 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
int channels_col = im_channels * filter_height * filter_width;
T* im_data = im.data<T>();
T* im_data = im->data<T>();
const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
......@@ -135,10 +133,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx =
h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
......@@ -171,35 +167,32 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
int filter_height = col->dims()[3];
int filter_width = col->dims()[4];
int col_height = col->dims()[0];
int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
T* col_data = col->data<T>();
for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
......@@ -209,9 +202,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_up;
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left;
col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) *
......@@ -244,34 +237,33 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>();
T* im_data = im->data<T>();
const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
......@@ -282,9 +274,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_up;
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left;
col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset =
(((col_row_idx * col_width + col_col_idx) * im_channels +
channel) *
......
......@@ -61,31 +61,30 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
......@@ -100,9 +99,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
im2col<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, im_height, im_width, dilation_h, dilation_w,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, col_height, col_width, col.data<T>());
im.data<T>(), num_outputs, im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[1], col_height, col_width, col->data<T>());
}
};
......@@ -163,31 +162,32 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
......@@ -206,9 +206,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), im_height, im_width, dilation_h, dilation_w,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, col_height, col_width, im.data<T>());
num_kernels, col.data<T>(), im_height, im_width, dilation[0],
dilation[1], filter_height, filter_width, stride[0], stride[1],
padding[0], padding[2], col_height, col_width, im->data<T>());
}
};
......@@ -222,11 +222,11 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels,
int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width,
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
int im_width, int filter_height, int filter_width,
int stride_height, int stride_width,
int padding_height, int padding_width, int col_height,
int col_width) {
int col_width, T* col_data) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < im_channels;
......@@ -263,30 +263,29 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
int filter_height = col->dims()[3];
int filter_width = col->dims()[4];
int col_height = col->dims()[0];
int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
......@@ -314,18 +313,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, col_height, col_width);
im.data<T>(), im_channels, im_height, im_width, filter_height,
filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
col_width, col->data<T>());
}
};
template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels,
int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width,
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
int im_width, int filter_height, int filter_width,
int stride_height, int stride_width,
int padding_height, int padding_width, int col_height,
int col_width) {
int col_width, T* im_data) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < im_channels;
......@@ -361,30 +360,31 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
......@@ -412,9 +412,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, col_height, col_width);
col.data<T>(), im_channels, im_height, im_width, filter_height,
filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
col_width, im->data<T>());
}
};
......
......@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
* \param colData Column data.
* \param colShape The shape of colData.
*
* \param dilations dilation data.
* \param 2-dimension [dilation_height, dilation_width].
*
* \param strides stride data.
* \param 2-dimension [stride_height, stride_width].
*
* \param paddings padding data.
* \param 4-dimension [up_pad, left_pad, down_pad, right_pad].
*
* If the template argument Format is kCFO, the shape of colData is:
* [input_channels, filter_height, filter_width, output_height, output_width]
* So, it is easy to reshape into a convolution matrix for convolution
......@@ -73,19 +82,19 @@ template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
const framework::Tensor& im, const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col);
};
template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im);
};
} // namespace math
......
......@@ -45,12 +45,14 @@ void testIm2col() {
int input_height = 2;
int input_width = 3;
int filter_size = 2;
int stride = 1;
int padding = 0;
int dilation_h = 1;
int dilation_w = 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
std::vector<int> stride({1, 1}); // stride_y, stride_x
std::vector<int> padding(
{0, 0, 0, 0}); // up_pad, left_pad, down_pad, right_pad
std::vector<int> dilation({1, 1}); // dilation_y, dilation_x
int output_height =
(input_height - filter_size + padding[0] + padding[1]) / stride[0] + 1;
int output_width =
(input_width - filter_size + padding[2] + padding[3]) / stride[1] + 1;
float* input_ptr = input_tmp.mutable_data<float>(
{1, input_height, input_width}, paddle::platform::CPUPlace());
float arr[6] = {0, 1, 2, 3, 4, 5};
......@@ -87,10 +89,8 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding, padding, padding, padding);
im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
stride, padding, padding, padding, padding);
im2col(*context, input, dilation, stride, padding, &output_cfo);
im2col_ocf(*context, input, dilation, stride, padding, &output_ocf);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
......@@ -133,8 +133,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context);
}
col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding, padding, padding, padding);
col2im(*context, output_cfo, dilation, stride, padding, &input);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -155,8 +154,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context);
}
col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
stride, padding, padding, padding, padding);
col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
......
......@@ -28,51 +28,51 @@ template <class T>
class Vol2ColFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
((dilation_d * (filter_depth - 1) + 1))) /
stride_depth +
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
"mismatching.");
const T* vol_data = vol.data<T>();
T* col_data = col.data<T>();
T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
......@@ -80,13 +80,11 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d;
int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad =
h * stride_height - padding_height + h_offset * dilation_h;
int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad =
w * stride_width - padding_width + w_offset * dilation_w;
int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
......@@ -116,18 +114,18 @@ template <class T>
class Col2VolFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......@@ -137,28 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
((dilation_d * (filter_depth - 1) + 1))) /
stride_depth +
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
"mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
"mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
"Mismatching.");
T* vol_data = vol.data<T>();
"mismatching.");
T* vol_data = vol->data<T>();
const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
......@@ -167,13 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d;
int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad =
h * stride_height - padding_height + h_offset * dilation_h;
int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad =
w * stride_width - padding_width + w_offset * dilation_w;
int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
......
......@@ -71,42 +71,42 @@ template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
((dilation_d * (filter_depth - 1) + 1))) /
stride_depth +
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
......@@ -121,10 +121,10 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilation_d, dilation_h, dilation_w, filter_depth, filter_height,
filter_width, stride_depth, stride_height, stride_width, padding_depth,
padding_height, padding_width, output_depth, output_height,
output_width, col.data<T>());
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], paddings[0],
paddings[1], paddings[2], output_depth, output_height, output_width,
col->data<T>());
}
};
......@@ -200,18 +200,18 @@ template <class T>
class Col2VolFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const {
PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int input_channels = vol->dims()[0];
int input_depth = vol->dims()[1];
int input_height = vol->dims()[2];
int input_width = vol->dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......@@ -219,23 +219,23 @@ class Col2VolFunctor<platform::GPUPlace, T> {
int output_height = col.dims()[5];
int output_width = col.dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth -
((dilation_d * (filter_depth - 1) + 1))) /
stride_depth +
PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1,
output_depth,
"input_depth and output_depth are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1,
output_height,
"input_height and output_height are "
"Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1,
output_width,
"input_width and output_width are "
......@@ -250,10 +250,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilation_d, dilation_h, dilation_w, filter_depth, filter_height,
filter_width, stride_depth, stride_height, stride_width, padding_depth,
padding_height, padding_width, output_depth, output_height,
output_width, vol.data<T>());
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], paddings[0],
paddings[1], paddings[2], output_depth, output_height, output_width,
vol->data<T>());
}
};
......
......@@ -31,6 +31,15 @@ namespace math {
* \param colData Column data.
* \param colShape The shape of colData.
*
* \param dilations dilation data.
* \param 3-dimension [dilation_depth, dilation_height, dilation_width].
*
* \param strides stride data.
* \param 3-dimension [stride_depth, stride_height, stride_width].
*
* \param paddings padding data.
* \param 3-dimension [d_pad, h_pad, w_pad].
*
* The shape of colData is:
* [input_channels, filter_depth, filter_height, filter_width, output_depth,
* output_height, output_width]
......@@ -57,22 +66,22 @@ template <typename Place, typename T>
class Vol2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const;
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* col) const;
};
template <typename Place, typename T>
class Col2VolFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const;
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* vol) const;
};
} // namespace math
......
......@@ -62,12 +62,15 @@ void testVol2col() {
int input_height = 2;
int input_width = 3;
int filter_size = 2;
int stride = 1;
int padding = 0;
int dilation = 1;
int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
std::vector<int> strides({1, 1, 1});
std::vector<int> paddings({0, 0, 0});
std::vector<int> dilations({1, 1, 1});
int output_depth =
(input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1;
int output_height =
(input_height - filter_size + 2 * paddings[1]) / strides[1] + 1;
int output_width =
(input_width - filter_size + 2 * paddings[2]) / strides[2] + 1;
// Vol2Col test
float* input_ptr =
......@@ -86,8 +89,7 @@ void testVol2col() {
*place);
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, input, output, dilation, dilation, dilation, stride, stride,
stride, padding, padding, padding);
vol2col(*context, input, dilations, strides, paddings, &output);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr;
......@@ -112,8 +114,7 @@ void testVol2col() {
}
paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, input, output, dilation, dilation, dilation, stride, stride,
stride, padding, padding, padding);
col2vol(*context, output, dilations, strides, paddings, &input);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册