Commit 356d6954 authored by C chengduoZH

follow comments

Parent 7d73b8fc
@@ -13,6 +13,7 @@
 limitations under the License. */
 #include "paddle/operators/conv_op.h"
+#include <vector>
 namespace paddle {
 namespace operators {
@@ -53,7 +54,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
       "The number of output channels should be divided by groups.");
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
-  for (size_t i = 0; i < paddings.size(); ++i) {
+  for (size_t i = 0; i < strides.size(); ++i) {
     PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
                            (dilations[i] * (filter_dims[i + 2] - 1) + 1) >
                        0,
@@ -61,8 +62,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                    "dilations, the output size is less than 0, please check "
                    "again.");
     output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                      dilations[i], paddings[i], paddings[i],
-                                      strides[i]));
+                                      dilations[i], paddings[i], strides[i]));
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
@@ -86,9 +86,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator. "
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int> default:{1, 1}), the "
+                            "strides(h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int> default:{0, 0}), the "
+                            "paddings(h_pad, w_pad) of "
+                            "convolution operator.")
       .SetDefault({0, 0});
   AddAttr<int>(
       "groups",
@@ -99,9 +105,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
   AddAttr<std::vector<int>>("dilations",
-                            "(vector default:{1, 1}), the dilations of "
+                            "(vector<int> default:{1, 1}), the "
+                            "dilations(h_dilation, w_dilation) of "
                             "convolution operator.")
-      .SetDefault(std::vector<int>{1, 1});
+      .SetDefault({1, 1});
   AddComment(R"DOC(
 Convolution Operator.
@@ -147,13 +154,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator."
             "The format of output tensor is also NCDHW.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector, default:{0, 0, 0}), the strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int>, default:{1, 1, 1}), the "
+                            "strides(d_stride, h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1, 1});
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "(vector, default:{0, 0, 0}), the paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int>, default:{0, 0, 0}), the "
+                            "paddings(d_pad, h_pad, w_pad) of convolution "
+                            "operator.")
       .SetDefault({0, 0, 0});
   AddAttr<int>(
       "groups",
@@ -164,10 +173,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
   AddAttr<std::vector<int>>("dilations",
-                            "(vector default:{1, 1, 1}), the dilations of "
+                            "(vector<int> default:{1, 1, 1}), the "
+                            "dilations(d_dilation, h_dilation, w_dilation) of "
                             "convolution operator. Currently, conv3d doesn't "
                             "support dilation.")
-      .SetDefault(std::vector<int>{1, 1, 1});
+      .SetDefault({1, 1, 1});
   AddComment(R"DOC(
 Convolution3D Operator.
......
@@ -28,24 +28,22 @@ using Tensor = framework::Tensor;
 // Base convolution operator definitions for other conv
 // like operators to reuse the implementation.
 inline int OutputSize(int input_size, int filter_size, int dilation,
-                      int padding_up, int padding_down, int stride) {
-  int output_size = (input_size + padding_up + padding_down -
-                     (dilation * (filter_size - 1) + 1)) /
-                        stride +
-                    1;
+                      int padding, int stride) {
+  const int dkernel = dilation * (filter_size - 1) + 1;
+  const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
   return output_size;
 }
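
Note: the refactored OutputSize folds the now-symmetric padding into 2 * padding and names the dilated kernel extent dkernel. A quick standalone check of the formula (the example values are hypothetical, not from this commit):

#include <cassert>

inline int OutputSize(int input_size, int filter_size, int dilation,
                      int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;  // dilated extent
  return (input_size + 2 * padding - dkernel) / stride + 1;
}

int main() {
  // A 3x3 filter with dilation 2 covers 5 input rows: (7 + 2*1 - 5) / 2 + 1 == 3.
  assert(OutputSize(7, 3, 2, 1, 2) == 3);
  // With dilation 1 this reduces to the familiar (in + 2p - k) / s + 1.
  assert(OutputSize(32, 3, 1, 1, 1) == 32);
  return 0;
}
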
-inline bool NotExpand(std::vector<int64_t>& filter_dim,
-                      std::vector<int>& strides, std::vector<int>& paddings,
-                      std::vector<int>& dilations) {
+inline bool IsExpand(std::vector<int64_t>& filter_dim,
+                     std::vector<int>& strides, std::vector<int>& paddings,
+                     std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 &= (static_cast<int>(filter_dim[j]) == 1);
-    strides_1 &= (strides[j] == 1);
-    padding_0 &= (paddings[j] == 0);
-    dilation_1 &= (dilations[j] == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    strides_1 = strides_1 && (strides[j] == 1);
+    padding_0 = padding_0 && (paddings[j] == 0);
+    dilation_1 = dilation_1 && (dilations[j] == 1);
   }
-  return filter_1 && strides_1 && padding_0 && dilation_1;
+  return !(filter_1 && strides_1 && padding_0 && dilation_1);
 }
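
Note: the rename from NotExpand to IsExpand removes a double negative at the call sites: IsExpand is true exactly when im2col/vol2col must materialize a column buffer, while a 1x1 filter with unit stride, zero padding and unit dilation can reuse the input memory directly. A standalone restatement of the predicate over the spatial dims (the driver values below are made up):

#include <cassert>
#include <cstdint>
#include <vector>

bool IsExpand(const std::vector<std::int64_t>& filter_dim,
              const std::vector<int>& strides,
              const std::vector<int>& paddings,
              const std::vector<int>& dilations) {
  bool trivial = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    trivial = trivial && static_cast<int>(filter_dim[j]) == 1 &&
              strides[j] == 1 && paddings[j] == 0 && dilations[j] == 1;
  }
  return !trivial;  // expand unless the convolution is trivial
}

int main() {
  assert(!IsExpand({1, 1}, {1, 1}, {0, 0}, {1, 1}));  // 1x1 conv: no buffer
  assert(IsExpand({3, 3}, {1, 1}, {1, 1}, {1, 1}));   // 3x3 conv: build col
  return 0;
}
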
 // Define Op classes in .h file so that other conv
@@ -65,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
 class ConvOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
@@ -88,9 +84,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());
-    int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -122,13 +118,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
     framework::DDim col_matrix_shape =
         framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
-    bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations);
+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
     Tensor col_matrix;
-    if (!not_expand) {
+    if (is_expand) {
       col.mutable_data<T>(col_shape, context.GetPlace());
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
@@ -149,51 +145,37 @@ class GemmConvKernel : public framework::OpKernel<T> {
     int in_step = static_cast<int>(input->dims()[1]) / groups;
     int out_step = static_cast<int>(output->dims()[1]) / groups;
-    if (!not_expand) {
-      for (int i = 0; i < batch_size; i++) {
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-        Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-        for (int g = 0; g < groups; g++) {
-          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-          if (filter_shape_vec.size() == 2) {
-            // im2col
-            math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-            im2col(context.device_context(), in_slice, col, dilations[0],
-                   dilations[1], strides[0], strides[1], paddings[0],
-                   paddings[0], paddings[1], paddings[1]);
-          } else if (filter_shape_vec.size() == 3) {
-            // vol2col
-            math::Vol2ColFunctor<Place, T> vol2col;
-            vol2col(context.device_context(), in_slice, col, dilations[0],
-                    dilations[1], dilations[2], strides[0], strides[1],
-                    strides[2], paddings[0], paddings[1], paddings[2]);
-          }
-          // gemm
-          Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), filter_slice, false,
-                                 col_matrix, false, T(1.0), &out_slice, T(0.0));
-        }
-      }
-    } else {
-      for (int i = 0; i < batch_size; i++) {
-        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-        Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
-        for (int g = 0; g < groups; g++) {
-          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-          col.ShareDataWith(in_slice);
-          col_matrix.ShareDataWith(col);
-          col_matrix.Resize(col_matrix_shape);
-          // gemm
-          Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), filter_slice, false,
-                                 col_matrix, false, T(1.0), &out_slice, T(0.0));
-        }
-      }
-    }
+    math::Vol2ColFunctor<Place, T> vol2col;
+    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+    for (int i = 0; i < batch_size; i++) {
+      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+      for (int g = 0; g < groups; g++) {
+        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+        if (!is_expand) {
+          col.ShareDataWith(in_slice);
+          col_matrix.ShareDataWith(col);
+          col_matrix.Resize(col_matrix_shape);
+        } else if (filter_shape_vec.size() == 2) {
+          // im2col
+          im2col(context.device_context(), in_slice, dilations, strides,
+                 std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                  paddings[1]},
+                 &col);
+        } else if (filter_shape_vec.size() == 3) {
+          // vol2col
+          vol2col(context.device_context(), in_slice, dilations, strides,
+                  paddings, &col);
+        }
+        // gemm
+        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(context.device_context(), filter_slice, false,
+                               col_matrix, false, T(1.0), &out_slice, T(0.0));
+      }
+    }
   }
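
Note: with the two branches merged, only the data-preparation step depends on is_expand; the gemm is shared. The rewritten im2col/vol2col interface takes dilation and stride as vectors and the padding as {up, left, down, right}, which is why the 2-D call above repeats paddings[0] and paddings[1]. A tiny sketch of that padding layout (the values are illustrative):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> paddings{1, 2};  // op attribute: {h_pad, w_pad}
  // What the kernel now hands to im2col/col2im: {up, left, down, right}.
  std::vector<int> pad4{paddings[0], paddings[1], paddings[0], paddings[1]};
  for (int p : pad4) std::cout << p << ' ';  // prints: 1 2 1 2
  std::cout << '\n';
  return 0;
}
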
@@ -217,9 +199,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     if (!input_grad && !filter_grad) return;
-    int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -270,13 +252,13 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     int in_step = static_cast<int>(input->dims()[1]) / groups;
     int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
-    bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations);
+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
     Tensor col_matrix;
-    if (!not_expand) {
+    if (is_expand) {
       col.mutable_data<T>(col_shape, context.GetPlace());
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
@@ -288,61 +270,38 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
       input_grad->mutable_data<T>(context.GetPlace());
       set_zero(context.device_context(), input_grad, static_cast<T>(0));
-      if (!not_expand) {
-        for (int i = 0; i < batch_size; i++) {
-          Tensor out_grad_batch =
-              output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-          Tensor in_grad_batch =
-              input_grad->Slice(i, i + 1).Resize(input_shape);
-          for (int g = 0; g < groups; g++) {
-            // gemm
-            Tensor out_grad_slice =
-                out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-            Tensor filter_slice =
-                filter.Slice(g * out_step, (g + 1) * out_step);
-            math::matmul<Place, T>(context.device_context(), filter_slice, true,
-                                   out_grad_slice, false, T(1.0), &col_matrix,
-                                   T(0.0));
-            Tensor in_grad_slice =
-                in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
-            if (filter_shape_vec.size() == 2) {
-              math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-              col2im(context.device_context(), in_grad_slice, col, dilations[0],
-                     dilations[1], strides[0], strides[1], paddings[0],
-                     paddings[0], paddings[1], paddings[1]);
-            } else if (filter_shape_vec.size() == 3) {
-              math::Col2VolFunctor<Place, T> col2vol;
-              col2vol(context.device_context(), in_grad_slice, col,
-                      dilations[0], dilations[1], dilations[2], strides[0],
-                      strides[1], strides[2], paddings[0], paddings[1],
-                      paddings[2]);
-            }
-          }
-        }
-      } else {
-        for (int i = 0; i < batch_size; i++) {
-          Tensor out_grad_batch =
-              output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-          Tensor in_grad_batch =
-              input_grad->Slice(i, i + 1).Resize(input_shape);
-          for (int g = 0; g < groups; g++) {
-            // gemm
-            Tensor out_grad_slice =
-                out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-            Tensor filter_slice =
-                filter.Slice(g * out_step, (g + 1) * out_step);
-            Tensor in_grad_slice =
-                in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
-            col_matrix.ShareDataWith(in_grad_slice);
-            col_matrix.Resize(col_matrix_shape);
-            math::matmul<Place, T>(context.device_context(), filter_slice, true,
-                                   out_grad_slice, false, T(1.0), &col_matrix,
-                                   T(0.0));
-          }
-        }
-      }
+      math::Col2VolFunctor<Place, T> col2vol;
+      math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+          Tensor in_grad_slice =
+              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col_matrix.ShareDataWith(in_grad_slice);
+            col_matrix.Resize(col_matrix_shape);
+          }
+          math::matmul<Place, T>(context.device_context(), filter_slice, true,
+                                 out_grad_slice, false, T(1.0), &col_matrix,
+                                 T(0.0));
+          if (is_expand && filter_shape_vec.size() == 2) {
+            col2im(context.device_context(), col, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &in_grad_slice);
+          } else if (is_expand && filter_shape_vec.size() == 3) {
+            col2vol(context.device_context(), col, dilations, strides, paddings,
+                    &in_grad_slice);
+          }
+        }
+      }
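
Note: the backward-input path first runs the gemm with the transpose flag set, so col_matrix = filter_slice^T * out_grad_slice, and only then scatters the columns back with col2im/col2vol (skipped entirely for the trivial 1x1 case). A shape walk-through with made-up sizes:

#include <iostream>

int main() {
  // One group: filter_slice viewed as (out_c, in_c * k_h * k_w),
  // out_grad_slice as (out_c, o_h * o_w).
  const int in_c = 4, out_c = 8, k_h = 3, k_w = 3, o_h = 5, o_w = 5;
  // matmul(filter_slice, /*trans=*/true, out_grad_slice, false) yields
  // col_matrix of shape (in_c * k_h * k_w, o_h * o_w):
  std::cout << "col_matrix: " << in_c * k_h * k_w << " x " << o_h * o_w
            << '\n';  // col_matrix: 36 x 25
  // col2im/col2vol then scatter-add those columns into in_grad_slice.
  return 0;
}
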
@@ -353,60 +312,38 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
       Tensor filter_grad_ = *filter_grad;
       filter_grad_.Resize(filter_matrix_shape);
       set_zero(context.device_context(), filter_grad, static_cast<T>(0));
-      if (!not_expand) {
-        for (int i = 0; i < batch_size; i++) {
-          Tensor out_grad_batch =
-              output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-          Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-          for (int g = 0; g < groups; g++) {
-            // im2col
-            Tensor out_grad_slice =
-                out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-            Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-            if (filter_shape_vec.size() == 2) {
-              math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-              im2col(context.device_context(), in_slice, col, dilations[0],
-                     dilations[1], strides[0], strides[1], paddings[0],
-                     paddings[0], paddings[1], paddings[1]);
-            } else if (filter_shape_vec.size() == 3) {
-              math::Vol2ColFunctor<Place, T> vol2col;
-              vol2col(context.device_context(), in_slice, col, dilations[0],
-                      dilations[1], dilations[2], strides[0], strides[1],
-                      strides[2], paddings[0], paddings[1], paddings[2]);
-            }
-            // gemm
-            Tensor filter_grad_slice =
-                filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-            math::matmul<Place, T>(context.device_context(), out_grad_slice,
-                                   false, col_matrix, true, T(1.0),
-                                   &filter_grad_slice, T(1.0));
-          }
-        }
-      } else {
-        for (int i = 0; i < batch_size; i++) {
-          Tensor out_grad_batch =
-              output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
-          Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
-          for (int g = 0; g < groups; g++) {
-            // im2col
-            Tensor out_grad_slice =
-                out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
-            Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
-            col.ShareDataWith(in_slice);
-            col_matrix.ShareDataWith(col);
-            col_matrix.Resize(col_matrix_shape);
-            // gemm
-            Tensor filter_grad_slice =
-                filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-            math::matmul<Place, T>(context.device_context(), out_grad_slice,
-                                   false, col_matrix, true, T(1.0),
-                                   &filter_grad_slice, T(1.0));
-          }
-        }
-      }
+      math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+      math::Vol2ColFunctor<Place, T> vol2col;
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // im2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col.ShareDataWith(in_slice);
+            col_matrix.ShareDataWith(col);
+            col_matrix.Resize(col_matrix_shape);
+          } else if (filter_shape_vec.size() == 2) {
+            im2col(context.device_context(), in_slice, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &col);
+          } else if (filter_shape_vec.size() == 3) {
+            vol2col(context.device_context(), in_slice, dilations, strides,
+                    paddings, &col);
+          }
+          // gemm
+          Tensor filter_grad_slice =
+              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(context.device_context(), out_grad_slice,
+                                 false, col_matrix, true, T(1.0),
+                                 &filter_grad_slice, T(1.0));
+        }
+      }
     }
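
Note: the filter-gradient gemm above passes T(1.0) as its final argument, which (assuming the usual alpha/beta convention of math::matmul) accumulates each batch item's contribution into filter_grad_slice rather than overwriting it; that is why filter_grad is zeroed once with set_zero before the loop. A trivial illustration of beta = 1 accumulation:

#include <iostream>

int main() {
  double filter_grad = 0.0;  // zeroed once, like set_zero above
  const double per_batch_update[3] = {0.5, -1.0, 2.0};  // made-up values
  for (int i = 0; i < 3; ++i) {
    filter_grad = 1.0 * filter_grad + per_batch_update[i];  // beta == 1.0
  }
  std::cout << filter_grad << '\n';  // 1.5
  return 0;
}
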
......
@@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
                     "as the number of filters.");
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
-  for (size_t i = 0; i < paddings.size(); ++i) {
+  for (size_t i = 0; i < strides.size(); ++i) {
     output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
                            filter_dims[i + 2]);
   }
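
Note: iterating over strides.size() (instead of paddings.size()) matches the number of spatial dims in both the 2-D and 3-D cases. The transpose output size itself is unchanged: out = (in - 1) * stride + filter, with no padding term yet. A quick numeric check:

#include <cassert>

int main() {
  // in = 5, stride = 2, filter = 3  ->  out = (5 - 1) * 2 + 3 = 11
  assert((5 - 1) * 2 + 3 == 11);
  // stride 1 simply appends filter - 1 rows: in = 7, filter = 4 -> out = 10
  assert((7 - 1) * 1 + 4 == 10);
  return 0;
}
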
@@ -77,13 +77,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
   AddOutput("Output",
             "(Tensor) The output tensor of convolution transpose operator. "
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector defalut:{1, 1}), strides of convolution transpose operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int> default:{1, 1}), strides of "
+                            "convolution transpose operator.")
       .SetDefault({1, 1});
   AddAttr<std::vector<int>>(
       "paddings",
-      "(vector defalut:{0, 0}), paddings of convolution transpose operator.")
+      "(vector<int> default:{0, 0}), paddings(h_pad, w_pad) of convolution "
+      "transpose operator.")
       .SetDefault({0, 0});
   AddComment(R"DOC(
 Convolution2D Transpose Operator.
@@ -132,13 +133,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
             "Where N is batch size, C is "
             "the number of channels, D is the depth of the feature, H is the "
             "height of the feature, and W is the width of the feature.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int> default:{1, 1, 1}), strides of "
+                            "convolution transpose operator.")
      .SetDefault({1, 1, 1});
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int> default:{0, 0, 0}), paddings(d_pad, "
+                            "h_pad, w_pad) of convolution transpose operator.")
       .SetDefault({0, 0, 0});
   AddComment(R"DOC(
 Convolution3D Transpose Operator.
......
@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvTransposeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
 class ConvTransposeOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
@@ -66,13 +62,11 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    // Actually, no paddings and groups allowed in conv transpose.
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     // TODO(Zhuoyuan): Paddings can be added in future.
     // groups will always be disabled in conv2dtranspose.
-    int dilaiton_d = 1;
-    int dilation_h = 1;
-    int dilation_w = 1;
     const int batch_size = static_cast<int>(input->dims()[0]);
     // input_shape_vec: {h, w} or {d, h, w}
@@ -124,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     math::SetConstant<Place, T> set_zero;
     set_zero(context.device_context(), output, static_cast<T>(0));
+    math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+    math::Col2VolFunctor<Place, T> col2vol;
+    std::vector<int> dilations({1, 1, 1});
     // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
     // on input)
     for (int i = 0; i < batch_size; i++) {
@@ -142,17 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
       if (filter_shape_vec.size() == 2) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
-        math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-        col2im(context.device_context(), output_batch, col, dilation_h,
-               dilation_w, strides[0], strides[1], 0, 0, 0, 0);
+        col2im(context.device_context(), col,
+               std::vector<int>{dilations[0], dilations[1]}, strides,
+               std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                paddings[1]},
+               &output_batch);
       } else if (filter_shape_vec.size() == 3) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
-        math::Col2VolFunctor<Place, T> col2vol;
-        col2vol(context.device_context(), output_batch, col, dilaiton_d,
-                dilation_h, dilation_w, strides[0], strides[1], strides[2], 0,
-                0, 0);
+        col2vol(context.device_context(), col, dilations, strides,
                std::vector<int>{0, 0, 0}, &output_batch);
       }
     }
   }
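
Note: conv-transpose forward reuses the conv backward-input machinery: a gemm produces col_matrix and col2im/col2vol folds it into the output. Per the comments above, col_matrix holds (c * k_h * k_w, h * w) and col2im turns it into (c, o_h, o_w); combined with the InferShape rule this gives, for made-up sizes:

#include <iostream>

int main() {
  const int c = 4, k_h = 3, k_w = 3, h = 5, w = 5, stride = 2;
  std::cout << "col_matrix: " << c * k_h * k_w << " x " << h * w
            << '\n';  // 36 x 25
  // o_h = (h - 1) * stride + k_h, likewise for o_w (no padding term yet).
  std::cout << "output: " << c << " x " << (h - 1) * stride + k_h << " x "
            << (w - 1) * stride + k_w << '\n';  // 4 x 11 x 11
  return 0;
}

(Also visible in the diff: the 2-D branch now forwards the real paddings, while the 3-D branch still passes std::vector<int>{0, 0, 0}, consistent with the comment that paddings are not yet supported.)
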
@@ -179,10 +176,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int dilaiton_d = 1;
-    int dilation_h = 1;
-    int dilation_w = 1;
     const int batch_size = static_cast<int>(input->dims()[0]);
     // input_shape_vec: {h, w} or {d, h, w}
@@ -237,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     Tensor filter_grad_;
     math::SetConstant<Place, T> set_zero;
+    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+    math::Vol2ColFunctor<Place, T> vol2col;
+    std::vector<int> dilations({1, 1, 1});
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
       set_zero(context.device_context(), input_grad, static_cast<T>(0));
@@ -256,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         if (filter_shape_vec.size() == 2) {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
-          math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-          im2col(context.device_context(), output_grad_batch, col, dilation_h,
-                 dilation_w, strides[0], strides[1], paddings[0], paddings[0],
-                 paddings[1], paddings[1]);
+          im2col(context.device_context(), output_grad_batch,
+                 std::vector<int>{dilations[0], dilations[1]}, strides,
+                 std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                  paddings[1]},
+                 &col);
         } else if (filter_shape_vec.size() == 3) {
           // vol2col: dy -> col_matrix
           // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
-          math::Vol2ColFunctor<Place, T> vol2col;
-          vol2col(context.device_context(), output_grad_batch, col, dilaiton_d,
-                  dilation_h, dilation_w, strides[0], strides[1], strides[2],
-                  paddings[0], paddings[1], paddings[2]);
+          vol2col(context.device_context(), output_grad_batch, dilations,
                  strides, paddings, &col);
         }
         if (input_grad) {
......
@@ -95,8 +95,9 @@ class ContextProjectFunctor {
     math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
-    int dilation_h = 1;
-    int dilation_w = 1;
+    std::vector<int> dilation({1, 1});
+    std::vector<int> padding({up_pad, 0, down_pad, 0});
+    std::vector<int> stride({context_stride, 1});
     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
@@ -126,10 +127,7 @@ class ContextProjectFunctor {
             {1, input_row_end - input_row_begin,
              sequence_width});  // input_channels, input_height, input_width
         in_t.Resize(framework::make_ddim(input_shape));
-        im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
-                   /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
-                   down_pad, 0, 0);
+        im2col_ocf(context, in_t, dilation, stride, padding, &out_t);
         out_t.Resize({sequence_height, context_length * sequence_width});
       }
     }
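
Note: for context projection the padding is asymmetric and vertical only, which the vector interface states directly: padding = {up_pad, 0, down_pad, 0} and stride = {context_stride, 1}, so each time step yields one context window. A quick check against the OCF size relation enforced in im2col.cc below (the numbers are made up):

#include <cassert>

int main() {
  // col_height == (im_height + pad_up + pad_down - filter_height) / stride + 1
  const int sequence_height = 6;  // im_height
  const int context_length = 3;   // filter_height
  const int up_pad = 1, down_pad = 1, context_stride = 1;
  assert((sequence_height + up_pad + down_pad - context_length) /
                 context_stride +
             1 ==
         6);  // one context window per time step
  return 0;
}
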
@@ -207,8 +205,9 @@ class ContextProjectGradFunctor {
     math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
-    int dilation_h = 1;
-    int dilation_w = 1;
+    std::vector<int> dilation({1, 1});
+    std::vector<int> padding({up_pad, 0, down_pad, 0});
+    std::vector<int> stride({context_stride, 1});
     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
@@ -240,9 +239,7 @@ class ContextProjectGradFunctor {
               sequence_width});  // input_channels, input_height, input_width
          in_t.Resize(framework::make_ddim(input_shape));
-          col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
-                     /*stride_height*/ context_stride, /*stride_width*/ 1,
-                     up_pad, down_pad, 0, 0);
+          col2im_ocf(context, out_t, dilation, stride, padding, &in_t);
          out_t.Resize({sequence_height, context_length * sequence_width});
        }
      }
......
@@ -28,40 +28,39 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                     platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& im, framework::Tensor& col,
-                  int dilation_h, int dilation_w, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& im, const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* col) {
     PADDLE_ENFORCE(im.dims().size() == 3);
-    PADDLE_ENFORCE(col.dims().size() == 5);
+    PADDLE_ENFORCE(col->dims().size() == 5);
     int im_channels = im.dims()[0];
     int im_height = im.dims()[1];
     int im_width = im.dims()[2];
-    int filter_height = col.dims()[1];
-    int filter_width = col.dims()[2];
-    int col_height = col.dims()[3];
-    int col_width = col.dims()[4];
+    int filter_height = col->dims()[1];
+    int filter_width = col->dims()[2];
+    int col_height = col->dims()[3];
+    int col_width = col->dims()[4];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
-                       ((dilation_h * (filter_height - 1) + 1))) /
-                              stride_height +
+    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+                       ((dilation[0] * (filter_height - 1) + 1))) /
+                              stride[0] +
                           1,
                       col_height,
                       "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
-                       ((dilation_w * (filter_width - 1) + 1))) /
-                              stride_width +
+    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
+                       ((dilation[1] * (filter_width - 1) + 1))) /
+                              stride[1] +
                           1,
                       col_width,
-                      "col_width and padding(padding_left, padding_right) are "
+                      "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
     int channels_col = im_channels * filter_height * filter_width;
     const T* im_data = im.data<T>();
-    T* col_data = col.data<T>();
+    T* col_data = col->data<T>();
     for (int c = 0; c < channels_col; ++c) {
       int w_offset = c % filter_width;
@@ -69,10 +68,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
       int c_im = c / filter_width / filter_height;
       for (int h = 0; h < col_height; ++h) {
         for (int w = 0; w < col_width; ++w) {
-          int im_row_idx =
-              h * stride_height - padding_up + h_offset * dilation_h;
-          int im_col_idx =
-              w * stride_width - padding_left + w_offset * dilation_w;
+          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
+          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
           int col_idx = (c * col_height + h) * col_width + w;
           int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
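
Note: the collapsed index expressions are unchanged in meaning: an output cell (h, w) with intra-filter offset (h_offset, w_offset) reads input row h * stride[0] - padding[0] + h_offset * dilation[0], and likewise for columns. Evaluated once with made-up numbers:

#include <iostream>

int main() {
  const int stride0 = 2, padding0 = 1, dilation0 = 2;  // height direction
  const int stride1 = 2, padding1 = 1, dilation1 = 2;  // width direction
  const int h = 1, h_offset = 1, w = 0, w_offset = 2;  // one col cell
  std::cout << h * stride0 - padding0 + h_offset * dilation0 << ','    // 3
            << w * stride1 - padding1 + w_offset * dilation1 << '\n';  // 3
  // Indices landing outside [0, im_height) x [0, im_width) fall in the
  // zero padding; im2col implementations write those col entries as zero
  // (the write itself is past the end of this hunk).
  return 0;
}
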
@@ -95,38 +92,39 @@ template <class T>
 class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                     platform::CPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int dilation_h, int dilation_w,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right) {
-    PADDLE_ENFORCE(im.dims().size() == 3);
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& col,
+                  const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* im) {
+    PADDLE_ENFORCE(im->dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int im_channels = im.dims()[0];
-    int im_height = im.dims()[1];
-    int im_width = im.dims()[2];
+    int im_channels = im->dims()[0];
+    int im_height = im->dims()[1];
+    int im_width = im->dims()[2];
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
     int col_width = col.dims()[4];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
-                       ((dilation_h * (filter_height - 1) + 1))) /
-                              stride_height +
+    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+                       ((dilation[0] * (filter_height - 1) + 1))) /
+                              stride[0] +
                           1,
                       col_height,
                       "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
-                       ((dilation_w * (filter_width - 1) + 1))) /
-                              stride_width +
+    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
-                       ((dilation_w * (filter_width - 1) + 1))) /
+                       ((dilation[1] * (filter_width - 1) + 1))) /
+                              stride[1] +
                           1,
                       col_width,
-                      "col_width and padding(padding_left, padding_right) are "
+                      "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
     int channels_col = im_channels * filter_height * filter_width;
-    T* im_data = im.data<T>();
+    T* im_data = im->data<T>();
     const T* col_data = col.data<T>();
     for (int c = 0; c < channels_col; ++c) {
@@ -135,10 +133,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
       int c_im = c / filter_width / filter_height;
       for (int h = 0; h < col_height; ++h) {
         for (int w = 0; w < col_width; ++w) {
-          int im_row_idx =
-              h * stride_height - padding_up + h_offset * dilation_h;
-          int im_col_idx =
-              w * stride_width - padding_left + w_offset * dilation_w;
+          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
+          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
           if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
               (im_col_idx) >= 0 && (im_col_idx) < im_width) {
@@ -171,35 +167,32 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                     platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& im, framework::Tensor& col,
-                  int dilation_h, int dilation_w, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& im, const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* col) {
     PADDLE_ENFORCE(im.dims().size() == 3);
-    PADDLE_ENFORCE(col.dims().size() == 5);
+    PADDLE_ENFORCE(col->dims().size() == 5);
     int im_channels = im.dims()[0];
     int im_height = im.dims()[1];
     int im_width = im.dims()[2];
-    int filter_height = col.dims()[3];
-    int filter_width = col.dims()[4];
-    int col_height = col.dims()[0];
-    int col_width = col.dims()[1];
+    int filter_height = col->dims()[3];
+    int filter_width = col->dims()[4];
+    int col_height = col->dims()[0];
+    int col_width = col->dims()[1];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
-                              stride_height +
-                          1,
-                      col_height,
-                      "Output_height and padding(padding_up, padding_down) are "
-                      "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
-                              stride_width +
-                          1,
-                      col_width,
-                      "col_width and padding(padding_left, padding_right) are "
-                      "inconsistent.");
+    PADDLE_ENFORCE_EQ(
+        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
+        col_height,
+        "Output_height and padding(padding_up, padding_down) are "
+        "inconsistent.");
+    PADDLE_ENFORCE_EQ(
+        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
+        col_width,
+        "col_width and padding(padding_left, padding_right) are "
+        "inconsistent.");
     const T* im_data = im.data<T>();
-    T* col_data = col.data<T>();
+    T* col_data = col->data<T>();
     for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
       for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
@@ -209,9 +202,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
           for (int filter_col_idx = 0; filter_col_idx < filter_width;
                ++filter_col_idx) {
             int im_row_offset =
-                col_row_idx * stride_height + filter_row_idx - padding_up;
+                col_row_idx * stride[0] + filter_row_idx - padding[0];
             int im_col_offset =
-                col_col_idx * stride_width + filter_col_idx - padding_left;
+                col_col_idx * stride[1] + filter_col_idx - padding[1];
             int col_offset =
                 ((((col_row_idx)*col_width + col_col_idx) * im_channels +
                   channel) *
@@ -244,34 +237,33 @@ template <class T>
 class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                     platform::CPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int dilation_h, int dilation_w,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right) {
-    PADDLE_ENFORCE(im.dims().size() == 3);
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& col,
+                  const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* im) {
+    PADDLE_ENFORCE(im->dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int im_channels = im.dims()[0];
-    int im_height = im.dims()[1];
-    int im_width = im.dims()[2];
+    int im_channels = im->dims()[0];
+    int im_height = im->dims()[1];
+    int im_width = im->dims()[2];
     int filter_height = col.dims()[3];
     int filter_width = col.dims()[4];
     int col_height = col.dims()[0];
     int col_width = col.dims()[1];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
-                              stride_height +
-                          1,
-                      col_height,
-                      "Output_height and padding(padding_up, padding_down) are "
-                      "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
-                              stride_width +
-                          1,
-                      col_width,
-                      "col_width and padding(padding_left, padding_right) are "
-                      "inconsistent.");
+    PADDLE_ENFORCE_EQ(
+        (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
+        col_height,
+        "Output_height and padding(padding_up, padding_down) are "
+        "inconsistent.");
+    PADDLE_ENFORCE_EQ(
+        (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
+        col_width,
+        "col_width and padding(padding_left, padding_right) are "
+        "inconsistent.");
-    T* im_data = im.data<T>();
+    T* im_data = im->data<T>();
     const T* col_data = col.data<T>();
     for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
@@ -282,9 +274,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
           for (int filter_col_idx = 0; filter_col_idx < filter_width;
                ++filter_col_idx) {
             int im_row_offset =
-                col_row_idx * stride_height + filter_row_idx - padding_up;
+                col_row_idx * stride[0] + filter_row_idx - padding[0];
             int im_col_offset =
-                col_col_idx * stride_width + filter_col_idx - padding_left;
+                col_col_idx * stride[1] + filter_col_idx - padding[1];
             int col_offset =
                 (((col_row_idx * col_width + col_col_idx) * im_channels +
                   channel) *
......
@@ -61,31 +61,30 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                     platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& im, framework::Tensor& col,
-                  int dilation_h, int dilation_w, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& im, const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* col) {
     PADDLE_ENFORCE(im.dims().size() == 3);
-    PADDLE_ENFORCE(col.dims().size() == 5);
+    PADDLE_ENFORCE(col->dims().size() == 5);
     int im_channels = im.dims()[0];
     int im_height = im.dims()[1];
     int im_width = im.dims()[2];
-    int filter_height = col.dims()[1];
-    int filter_width = col.dims()[2];
-    int col_height = col.dims()[3];
-    int col_width = col.dims()[4];
+    int filter_height = col->dims()[1];
+    int filter_width = col->dims()[2];
+    int col_height = col->dims()[3];
+    int col_width = col->dims()[4];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
-                       (dilation_h * (filter_height - 1) + 1)) /
-                              stride_height +
+    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+                       (dilation[0] * (filter_height - 1) + 1)) /
+                              stride[0] +
                           1,
                       col_height,
                       "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
-                       (dilation_w * (filter_width - 1) + 1)) /
-                              stride_width +
+    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
+                       (dilation[1] * (filter_width - 1) + 1)) /
+                              stride[1] +
                           1,
                       col_width,
                       "col_width and padding(padding_left, padding_right) are "
@@ -100,9 +99,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     im2col<T><<<grid, threads, 0,
                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
-        im.data<T>(), num_outputs, im_height, im_width, dilation_h, dilation_w,
-        filter_height, filter_width, stride_height, stride_width, padding_up,
-        padding_left, col_height, col_width, col.data<T>());
+        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
+        dilation[1], filter_height, filter_width, stride[0], stride[1],
+        padding[0], padding[1], col_height, col_width, col->data<T>());
   }
 };
@@ -163,31 +162,32 @@ template <class T>
 class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                     platform::GPUPlace, T> {
  public:
-  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int dilation_h, int dilation_w,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right) {
-    PADDLE_ENFORCE(im.dims().size() == 3);
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& col,
+                  const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* im) {
+    PADDLE_ENFORCE(im->dims().size() == 3);
     PADDLE_ENFORCE(col.dims().size() == 5);
-    int im_channels = im.dims()[0];
-    int im_height = im.dims()[1];
-    int im_width = im.dims()[2];
+    int im_channels = im->dims()[0];
+    int im_height = im->dims()[1];
+    int im_width = im->dims()[2];
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
     int col_width = col.dims()[4];
-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
-                       (dilation_h * (filter_height - 1) + 1)) /
-                              stride_height +
+    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+                       (dilation[0] * (filter_height - 1) + 1)) /
+                              stride[0] +
                           1,
                       col_height,
                       "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
-                       (dilation_w * (filter_width - 1) + 1)) /
-                              stride_width +
+    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
+                       (dilation[1] * (filter_width - 1) + 1)) /
+                              stride[1] +
                           1,
                       col_width,
                       "col_width and padding(padding_left, padding_right) are "
@@ -206,9 +206,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
     col2im<T><<<grid, threads, 0,
                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
-        num_kernels, col.data<T>(), im_height, im_width, dilation_h, dilation_w,
-        filter_height, filter_width, stride_height, stride_width, padding_up,
-        padding_left, col_height, col_width, im.data<T>());
+        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
+        dilation[1], filter_height, filter_width, stride[0], stride[1],
+        padding[0], padding[2], col_height, col_width, im->data<T>());
   }
 };
@@ -222,11 +222,11 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::GPUPlace, double>;
 template <class T>
-__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels,
-                          int im_height, int im_width, int filter_height,
-                          int filter_width, int stride_height, int stride_width,
-                          int padding_height, int padding_width, int col_height,
-                          int col_width) {
+__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
+                          int im_width, int filter_height, int filter_width,
+                          int stride_height, int stride_width,
+                          int padding_height, int padding_width, int col_height,
+                          int col_width, T* col_data) {
   int swid = blockIdx.x;
   int shid = blockIdx.y;
   for (int channelid = threadIdx.z; channelid < im_channels;
...@@ -263,30 +263,29 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -263,30 +263,29 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, const std::vector<int>& dilation,
int dilation_h, int dilation_w, int stride_height, const std::vector<int>& stride,
int stride_width, int padding_up, int padding_down, const std::vector<int>& padding, framework::Tensor* col) {
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col->dims().size() == 5);
int im_channels = im.dims()[0]; int im_channels = im.dims()[0];
int im_height = im.dims()[1]; int im_height = im.dims()[1];
int im_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col->dims()[3];
int filter_width = col.dims()[4]; int filter_width = col->dims()[4];
int col_height = col.dims()[0]; int col_height = col->dims()[0];
int col_width = col.dims()[1]; int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation_h * (filter_height - 1) + 1)) / (dilation[0] * (filter_height - 1) + 1)) /
stride_height + stride[0] +
1, 1,
col_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation_w * (filter_width - 1) + 1)) / (dilation[1] * (filter_width - 1) + 1)) /
stride_width + stride[1] +
1, 1,
col_width, col_width,
"col_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
...@@ -314,18 +313,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -314,18 +313,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
im2colOCF<T><<<grid, threads, 0, im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), im_channels, im_height, im_width, im.data<T>(), im_channels, im_height, im_width, filter_height,
filter_height, filter_width, stride_height, stride_width, padding_up, filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
padding_left, col_height, col_width); col_width, col->data<T>());
} }
}; };
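Editor's note: in the kOCF kernels each (blockIdx.y, blockIdx.x) pair handles one output position (shid, swid) and the z threads walk the channels, so the column buffer is laid out as [col_height, col_width, im_channels, filter_height, filter_width]. Assuming the usual row-major flattening (the kernel bodies are elided in this diff), the flat offset of one patch element can be sketched as:

// Hypothetical helper: flat offset into an OCF-layout column buffer,
// assuming row-major order over [col_height][col_width][channels][fh][fw].
inline int OcfColOffsetSketch(int shid, int swid, int channelid, int kh,
                              int kw, int col_width, int im_channels,
                              int filter_height, int filter_width) {
  return (((shid * col_width + swid) * im_channels + channelid) *
              filter_height + kh) * filter_width + kw;
}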
template <class T> template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, __global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
int im_height, int im_width, int filter_height, int im_width, int filter_height, int filter_width,
int filter_width, int stride_height, int stride_width, int stride_height, int stride_width,
int padding_height, int padding_width, int col_height, int padding_height, int padding_width, int col_height,
int col_width) { int col_width, T* im_data) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < im_channels; for (int channelid = threadIdx.z; channelid < im_channels;
...@@ -361,30 +360,31 @@ template <class T> ...@@ -361,30 +360,31 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context,
const framework::Tensor& col, int dilation_h, int dilation_w, const framework::Tensor& col,
int stride_height, int stride_width, int padding_up, const std::vector<int>& dilation,
int padding_down, int padding_left, int padding_right) { const std::vector<int>& stride,
PADDLE_ENFORCE(im.dims().size() == 3); const std::vector<int>& padding, framework::Tensor* im) {
PADDLE_ENFORCE(im->dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im.dims()[0]; int im_channels = im->dims()[0];
int im_height = im.dims()[1]; int im_height = im->dims()[1];
int im_width = im.dims()[2]; int im_width = im->dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int col_height = col.dims()[0]; int col_height = col.dims()[0];
int col_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation_h * (filter_height - 1) + 1)) / (dilation[0] * (filter_height - 1) + 1)) /
stride_height + stride[0] +
1, 1,
col_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation_w * (filter_width - 1) + 1)) / (dilation[1] * (filter_width - 1) + 1)) /
stride_width + stride[1] +
1, 1,
col_width, col_width,
"col_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
...@@ -412,9 +412,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -412,9 +412,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col2imOCF<T><<<grid, threads, 0, col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), im_channels, im_height, im_width, col.data<T>(), im_channels, im_height, im_width, filter_height,
filter_height, filter_width, stride_height, stride_width, padding_up, filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
padding_left, col_height, col_width); col_width, im->data<T>());
} }
}; };
......
...@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; ...@@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
* \param colData Column data. * \param colData Column data.
* \param colShape The shape of colData. * \param colShape The shape of colData.
* *
* \param dilations dilation data.
*        2-dimension [dilation_height, dilation_width].
*
* \param strides stride data.
*        2-dimension [stride_height, stride_width].
*
* \param paddings padding data.
*        4-dimension [up_pad, left_pad, down_pad, right_pad].
*
* If the template argument Format is kCFO, the shape of colData is: * If the template argument Format is kCFO, the shape of colData is:
* [input_channels, filter_height, filter_width, output_height, output_width] * [input_channels, filter_height, filter_width, output_height, output_width]
* So, it is easy to reshape into a convolution matrix for convolution * So, it is easy to reshape into a convolution matrix for convolution
...@@ -73,19 +82,19 @@ template <ColFormat Format, typename Place, typename T> ...@@ -73,19 +82,19 @@ template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor { class Im2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, const std::vector<int>& dilation,
int dilation_h, int dilation_w, int stride_height, const std::vector<int>& stride,
int stride_width, int padding_up, int padding_down, const std::vector<int>& padding, framework::Tensor* col);
int padding_left, int padding_right);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context,
const framework::Tensor& col, int dilation_h, int dilation_w, const framework::Tensor& col,
int stride_height, int stride_width, int padding_up, const std::vector<int>& dilation,
int padding_down, int padding_left, int padding_right); const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im);
}; };
} // namespace math } // namespace math
......
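Editor's note: with this signature change a caller passes the three attribute vectors directly and the tensor being written moves to a pointer parameter, so the in/out roles are visible at the call site. A hedged usage sketch (context and tensor setup omitted; shapes assumed consistent):

// im: [channels, height, width]; col: [channels, fh, fw, oh, ow] for kCFO.
std::vector<int> dilation({1, 1});       // h_dilation, w_dilation
std::vector<int> stride({1, 1});         // h_stride, w_stride
std::vector<int> padding({0, 0, 0, 0});  // up, left, down, right
paddle::operators::math::Im2ColFunctor<
    paddle::operators::math::ColFormat::kCFO, Place, float> im2col;
im2col(*context, im, dilation, stride, padding, &col);  // writes col
paddle::operators::math::Col2ImFunctor<
    paddle::operators::math::ColFormat::kCFO, Place, float> col2im;
col2im(*context, col, dilation, stride, padding, &im);  // writes im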
...@@ -45,12 +45,14 @@ void testIm2col() { ...@@ -45,12 +45,14 @@ void testIm2col() {
int input_height = 2; int input_height = 2;
int input_width = 3; int input_width = 3;
int filter_size = 2; int filter_size = 2;
int stride = 1; std::vector<int> stride({1, 1}); // stride_y, stride_x
int padding = 0; std::vector<int> padding(
int dilation_h = 1; {0, 0, 0, 0}); // up_pad, left_pad, down_pad, right_pad
int dilation_w = 1; std::vector<int> dilation({1, 1}); // dilation_y, dilation_x
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_height =
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; (input_height - filter_size + padding[0] + padding[2]) / stride[0] + 1;
int output_width =
(input_width - filter_size + padding[1] + padding[3]) / stride[1] + 1;
float* input_ptr = input_tmp.mutable_data<float>( float* input_ptr = input_tmp.mutable_data<float>(
{1, input_height, input_width}, paddle::platform::CPUPlace()); {1, input_height, input_width}, paddle::platform::CPUPlace());
float arr[6] = {0, 1, 2, 3, 4, 5}; float arr[6] = {0, 1, 2, 3, 4, 5};
...@@ -87,10 +89,8 @@ void testIm2col() { ...@@ -87,10 +89,8 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float> paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, im2col(*context, input, dilation, stride, padding, &output_cfo);
padding, padding, padding, padding); im2col_ocf(*context, input, dilation, stride, padding, &output_ocf);
im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
stride, padding, padding, padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
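Editor's note: both expected buffers can be derived by hand. The 1 x 2 x 3 input holds 0..5 row-major (row 0: 0 1 2, row 1: 3 4 5), and a 2 x 2 filter with stride 1 and zero padding yields a 1 x 2 output, so every filter tap sees two output positions:

// kCFO groups by filter tap (kh, kw), output positions innermost:
//   (0,0) -> 0,1   (0,1) -> 1,2   (1,0) -> 3,4   (1,1) -> 4,5
//   => out_cfo_data = {0, 1, 1, 2, 3, 4, 4, 5}
// kOCF groups by output position, the 2 x 2 patch innermost:
//   (0,0) -> 0,1,3,4   (0,1) -> 1,2,4,5
//   => out_ocf_data = {0, 1, 3, 4, 1, 2, 4, 5}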
...@@ -133,8 +133,7 @@ void testIm2col() { ...@@ -133,8 +133,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, col2im(*context, output_cfo, dilation, stride, padding, &input);
padding, padding, padding, padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -155,8 +154,7 @@ void testIm2col() { ...@@ -155,8 +154,7 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);
stride, padding, padding, padding, padding);
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
......
...@@ -28,51 +28,51 @@ template <class T> ...@@ -28,51 +28,51 @@ template <class T>
class Vol2ColFunctor<platform::CPUPlace, T> { class Vol2ColFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const { framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol.dims()[1];
int input_height = vol.dims()[2]; int input_height = vol.dims()[2];
int input_width = vol.dims()[3]; int input_width = vol.dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col.dims()[2]; int filter_height = col->dims()[2];
int filter_width = col.dims()[3]; int filter_width = col->dims()[3];
int output_depth = col.dims()[4]; int output_depth = col->dims()[4];
int output_height = col.dims()[5]; int output_height = col->dims()[5];
int output_width = col.dims()[6]; int output_width = col->dims()[6];
int channels_col = int channels_col =
input_channels * filter_depth * filter_height * filter_width; input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilation_d * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
stride_depth + strides[0] +
1, 1,
output_depth, output_depth,
"input_depth and output_depth are " "input_depth and output_depth are "
"Mismatching."); "mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilation_h * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
stride_height + strides[1] +
1, 1,
output_height, output_height,
"input_height and output_height are " "input_height and output_height are "
"Mismatching."); "mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilation_w * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
stride_width + strides[2] +
1, 1,
output_width, output_width,
"input_width and output_width are " "input_width and output_width are "
"Mismatching."); "mismatching.");
const T* vol_data = vol.data<T>(); const T* vol_data = vol.data<T>();
T* col_data = col.data<T>(); T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
...@@ -80,13 +80,11 @@ class Vol2ColFunctor<platform::CPUPlace, T> { ...@@ -80,13 +80,11 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth; int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth; int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) { for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
int h_pad = int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
h * stride_height - padding_height + h_offset * dilation_h;
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int w_pad = int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
w * stride_width - padding_width + w_offset * dilation_w;
int col_idx = int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w; ((c * output_depth + d) * output_height + h) * output_width + w;
...@@ -116,18 +114,18 @@ template <class T> ...@@ -116,18 +114,18 @@ template <class T>
class Col2VolFunctor<platform::CPUPlace, T> { class Col2VolFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const { framework::Tensor* vol) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol->dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol->dims()[1];
int input_height = vol.dims()[2]; int input_height = vol->dims()[2];
int input_width = vol.dims()[3]; int input_width = vol->dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
...@@ -137,28 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> { ...@@ -137,28 +135,28 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int channels_col = int channels_col =
input_channels * filter_depth * filter_height * filter_width; input_channels * filter_depth * filter_height * filter_width;
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilation_d * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
stride_depth + strides[0] +
1, 1,
output_depth, output_depth,
"input_depth and output_depth are " "input_depth and output_depth are "
"Mismatching."); "mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilation_h * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
stride_height + strides[1] +
1, 1,
output_height, output_height,
"input_height and output_height are " "input_height and output_height are "
"Mismatching."); "mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilation_w * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
stride_width + strides[2] +
1, 1,
output_width, output_width,
"input_width and output_width are " "input_width and output_width are "
"Mismatching."); "mismatching.");
T* vol_data = vol.data<T>(); T* vol_data = vol->data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
...@@ -167,13 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> { ...@@ -167,13 +165,11 @@ class Col2VolFunctor<platform::CPUPlace, T> {
int d_offset = (c / filter_width / filter_height) % filter_depth; int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth; int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) { for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
int h_pad = int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1];
h * stride_height - padding_height + h_offset * dilation_h;
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int w_pad = int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2];
w * stride_width - padding_width + w_offset * dilation_w;
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
......
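Editor's note: the CPU loops above flatten (channel, kd, kh, kw) into a single channels_col axis and recover each filter tap by repeated division, the exact inverse of row-major flattening. A standalone sketch of that decode step (struct and helper names invented for illustration):

// Recover (c_in, d_offset, h_offset, w_offset) from a flattened column
// channel c, where channels_col = input_channels * fd * fh * fw.
struct TapSketch {
  int c_in, d_offset, h_offset, w_offset;
};
inline TapSketch DecodeTapSketch(int c, int filter_depth, int filter_height,
                                 int filter_width) {
  TapSketch t;
  t.w_offset = c % filter_width;
  t.h_offset = (c / filter_width) % filter_height;
  t.d_offset = (c / filter_width / filter_height) % filter_depth;
  t.c_in = c / filter_width / filter_height / filter_depth;
  return t;
}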
...@@ -71,42 +71,42 @@ template <class T> ...@@ -71,42 +71,42 @@ template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> { class Vol2ColFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const { framework::Tensor* col) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col->dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol.dims()[1];
int input_height = vol.dims()[2]; int input_height = vol.dims()[2];
int input_width = vol.dims()[3]; int input_width = vol.dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col.dims()[2]; int filter_height = col->dims()[2];
int filter_width = col.dims()[3]; int filter_width = col->dims()[3];
int output_depth = col.dims()[4]; int output_depth = col->dims()[4];
int output_height = col.dims()[5]; int output_height = col->dims()[5];
int output_width = col.dims()[6]; int output_width = col->dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilation_d * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
stride_depth + strides[0] +
1, 1,
output_depth, output_depth,
"input_depth and output_depth are " "input_depth and output_depth are "
"Mismatching."); "Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilation_h * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
stride_height + strides[1] +
1, 1,
output_height, output_height,
"input_height and output_height are " "input_height and output_height are "
"Mismatching."); "Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilation_w * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
stride_width + strides[2] +
1, 1,
output_width, output_width,
"input_width and output_width are " "input_width and output_width are "
...@@ -121,10 +121,10 @@ class Vol2ColFunctor<platform::GPUPlace, T> { ...@@ -121,10 +121,10 @@ class Vol2ColFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width, num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilation_d, dilation_h, dilation_w, filter_depth, filter_height, dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, stride_depth, stride_height, stride_width, padding_depth, filter_width, strides[0], strides[1], strides[2], paddings[0],
padding_height, padding_width, output_depth, output_height, paddings[1], paddings[2], output_depth, output_height, output_width,
output_width, col.data<T>()); col->data<T>());
} }
}; };
...@@ -200,18 +200,18 @@ template <class T> ...@@ -200,18 +200,18 @@ template <class T>
class Col2VolFunctor<platform::GPUPlace, T> { class Col2VolFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const { framework::Tensor* vol) const {
PADDLE_ENFORCE(vol.dims().size() == 4); PADDLE_ENFORCE(vol->dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7); PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0]; int input_channels = vol->dims()[0];
int input_depth = vol.dims()[1]; int input_depth = vol->dims()[1];
int input_height = vol.dims()[2]; int input_height = vol->dims()[2];
int input_width = vol.dims()[3]; int input_width = vol->dims()[3];
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
...@@ -219,23 +219,23 @@ class Col2VolFunctor<platform::GPUPlace, T> { ...@@ -219,23 +219,23 @@ class Col2VolFunctor<platform::GPUPlace, T> {
int output_height = col.dims()[5]; int output_height = col.dims()[5];
int output_width = col.dims()[6]; int output_width = col.dims()[6];
PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] -
((dilation_d * (filter_depth - 1) + 1))) / ((dilations[0] * (filter_depth - 1) + 1))) /
stride_depth + strides[0] +
1, 1,
output_depth, output_depth,
"input_depth and output_depth are " "input_depth and output_depth are "
"Mismatching."); "Mismatching.");
PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] -
((dilation_h * (filter_height - 1) + 1))) / ((dilations[1] * (filter_height - 1) + 1))) /
stride_height + strides[1] +
1, 1,
output_height, output_height,
"input_height and output_height are " "input_height and output_height are "
"Mismatching."); "Mismatching.");
PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] -
((dilation_w * (filter_width - 1) + 1))) / ((dilations[2] * (filter_width - 1) + 1))) /
stride_width + strides[2] +
1, 1,
output_width, output_width,
"input_width and output_width are " "input_width and output_width are "
...@@ -250,10 +250,10 @@ class Col2VolFunctor<platform::GPUPlace, T> { ...@@ -250,10 +250,10 @@ class Col2VolFunctor<platform::GPUPlace, T> {
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width, num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilation_d, dilation_h, dilation_w, filter_depth, filter_height, dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, stride_depth, stride_height, stride_width, padding_depth, filter_width, strides[0], strides[1], strides[2], paddings[0],
padding_height, padding_width, output_depth, output_height, paddings[1], paddings[2], output_depth, output_height, output_width,
output_width, vol.data<T>()); vol->data<T>());
} }
}; };
......
...@@ -31,6 +31,15 @@ namespace math { ...@@ -31,6 +31,15 @@ namespace math {
* \param colData Column data. * \param colData Column data.
* \param colShape The shape of colData. * \param colShape The shape of colData.
* *
* \param dilations dilation data.
*        3-dimension [dilation_depth, dilation_height, dilation_width].
*
* \param strides stride data.
*        3-dimension [stride_depth, stride_height, stride_width].
*
* \param paddings padding data.
*        3-dimension [d_pad, h_pad, w_pad].
*
* The shape of colData is: * The shape of colData is:
* [input_channels, filter_depth, filter_height, filter_width, output_depth, * [input_channels, filter_depth, filter_height, filter_width, output_depth,
* output_height, output_width] * output_height, output_width]
...@@ -57,22 +66,22 @@ template <typename Place, typename T> ...@@ -57,22 +66,22 @@ template <typename Place, typename T>
class Vol2ColFunctor { class Vol2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col, const framework::Tensor& vol,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const; framework::Tensor* col) const;
}; };
template <typename Place, typename T> template <typename Place, typename T>
class Col2VolFunctor { class Col2VolFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col, const framework::Tensor& col,
int dilation_d, int dilation_h, int dilation_w, const std::vector<int>& dilations,
int stride_depth, int stride_height, int stride_width, const std::vector<int>& strides,
int padding_depth, int padding_height, const std::vector<int>& paddings,
int padding_width) const; framework::Tensor* vol) const;
}; };
} // namespace math } // namespace math
......
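Editor's note: the 3-D functors follow the same convention as im2col, attribute vectors in and the result tensor through a pointer. A hedged call-site sketch (allocation and device setup omitted):

// vol: [channels, depth, height, width];
// col: [channels, fd, fh, fw, od, oh, ow].
std::vector<int> dilations({1, 1, 1});  // depth, height, width
std::vector<int> strides({1, 1, 1});
std::vector<int> paddings({0, 0, 0});   // d_pad, h_pad, w_pad
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, vol, dilations, strides, paddings, &col);  // writes col
paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, col, dilations, strides, paddings, &vol);  // writes vol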
...@@ -62,12 +62,15 @@ void testVol2col() { ...@@ -62,12 +62,15 @@ void testVol2col() {
int input_height = 2; int input_height = 2;
int input_width = 3; int input_width = 3;
int filter_size = 2; int filter_size = 2;
int stride = 1; std::vector<int> strides({1, 1, 1});
int padding = 0; std::vector<int> paddings({0, 0, 0});
int dilation = 1; std::vector<int> dilations({1, 1, 1});
int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; int output_depth =
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; (input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; int output_height =
(input_height - filter_size + 2 * paddings[1]) / strides[1] + 1;
int output_width =
(input_width - filter_size + 2 * paddings[2]) / strides[2] + 1;
// Vol2Col test // Vol2Col test
float* input_ptr = float* input_ptr =
...@@ -86,8 +89,7 @@ void testVol2col() { ...@@ -86,8 +89,7 @@ void testVol2col() {
*place); *place);
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col; paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, vol2col(*context, input, dilations, strides, paddings, &output);
stride, padding, padding, padding);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
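Editor's note: this buffer also follows from the shapes. The 1 x 2 x 2 x 3 input holds 0..11, and the 2 x 2 x 2 filter with unit strides and zero padding gives a 1 x 1 x 2 output, so channels_col = 1 * 2 * 2 * 2 = 8 and each tap contributes two values:

// depth 0:  0 1 2 / 3 4 5     depth 1:  6 7 8 / 9 10 11
// taps in (kd, kh, kw) order, two output positions each:
//   (0,0,0) -> 0,1    (0,0,1) -> 1,2    (0,1,0) -> 3,4    (0,1,1) -> 4,5
//   (1,0,0) -> 6,7    (1,0,1) -> 7,8    (1,1,0) -> 9,10   (1,1,1) -> 10,11
//   => vol_2_col = {0,1,1,2,3,4,4,5,6,7,7,8,9,10,10,11}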
float* out_cfo_ptr; float* out_cfo_ptr;
...@@ -112,8 +114,7 @@ void testVol2col() { ...@@ -112,8 +114,7 @@ void testVol2col() {
} }
paddle::operators::math::Col2VolFunctor<Place, float> col2vol; paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, col2vol(*context, output, dilations, strides, paddings, &input);
stride, padding, padding, padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
......