提交 f302c6a3 编写于 作者: C chengduoZH

write conv2d and conv3d together

上级 ba7db29d
...@@ -41,8 +41,8 @@ namespace ops = paddle::operators; ...@@ -41,8 +41,8 @@ namespace ops = paddle::operators;
REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(conv_cudnn,
conv_cudnn, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>); ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv_cudnn_grad, conv_cudnn_grad,
ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>); ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
...@@ -198,12 +198,12 @@ namespace ops = paddle::operators; ...@@ -198,12 +198,12 @@ namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d, ops::GemmConv3DKernel<paddle::platform::CPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::CPUPlace, float>);
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv3d, ops::GemmConv3DKernel<paddle::platform::GPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::GPUPlace, float>);
...@@ -62,7 +62,7 @@ class ConvOpGrad : public framework::OperatorWithKernel { ...@@ -62,7 +62,7 @@ class ConvOpGrad : public framework::OperatorWithKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel<T> { class GemmConvKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -77,49 +77,78 @@ class GemmConv2DKernel : public framework::OpKernel<T> { ...@@ -77,49 +77,78 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
int batch_size = input->dims()[0]; const int batch_size = static_cast<int>(input->dims()[0]);
int input_channels = input->dims()[1];
int filter_height = filter.dims()[filter.dims().size() - 2]; // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
int filter_width = filter.dims()[filter.dims().size() - 1]; std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
int output_channels = output->dims()[1]; filter_shape_vec.erase(filter_shape_vec.begin(),
int output_height = output->dims()[2]; filter_shape_vec.begin() + 2);
int output_width = output->dims()[3];
// output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
output_shape_vec.erase(output_shape_vec.begin(),
output_shape_vec.begin() + 2);
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
// use col_shape in the im2col calculation // use col_shape in the im2col calculation
framework::DDim col_shape = {input_channels / groups, filter_height, // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
filter_width, output_height, output_width}; // o_h, o_w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(input->dims()[1] / groups);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
output_shape_vec.end());
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// use col_matrix_shape in the gemm calculation // use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = { // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
input_channels / groups * filter_height * filter_width, // o_h * o_w)
output_height * output_width}; framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col, // col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape // but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix = col; Tensor col_matrix;
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2], framework::DDim input_shape = framework::slice_ddim(
input->dims()[3]}; input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {output_channels, framework::DDim output_matrix_shape = {
output_height * output_width}; output->dims()[1],
// convolution operator: im2col + gemm output->numel() / (output->dims()[0] * output->dims()[1])};
int in_step = input_channels / groups;
int out_step = output_channels / groups; // convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
// im2col
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[0], paddings[1], paddings[1]); if (filter_shape_vec.size() == 2) {
// im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
} else if (filter_shape_vec.size() == 3) {
// vol2col
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
}
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
...@@ -132,7 +161,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> { ...@@ -132,7 +161,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel<T> { class GemmConvGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -142,267 +171,74 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> { ...@@ -142,267 +171,74 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Input")); context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter")); context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations, // The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation, // so here use an assignment operation,
// that avoids modifying the variable in the Scope. // that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
if (!input_grad && !filter_grad) return;
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
int batch_size = input->dims()[0]; const int batch_size = static_cast<int>(input->dims()[0]);
int input_channels = input->dims()[1];
int filter_height = filter.dims()[filter.dims().size() - 2];
int filter_width = filter.dims()[filter.dims().size() - 1];
int output_channels = output_grad->dims()[1];
int output_height = output_grad->dims()[2];
int output_width = output_grad->dims()[3];
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {input_channels / groups, filter_height,
filter_width, output_height, output_width};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {
input_channels / groups * filter_height * filter_width,
output_height * output_width};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3]};
framework::DDim output_matrix_shape = {
output_grad->dims()[1],
output_grad->dims()[2] * output_grad->dims()[3]};
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
// convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm
int in_step = input_channels / groups;
int out_step = output_channels / groups;
math::SetConstant<Place, T> set_zero;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
// col2im
Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
}
}
}
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
set_zero(context.device_context(), filter_grad, static_cast<T>(0));
for (int i = 0; i < batch_size; i++) { // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
Tensor out_grad_batch = std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
output_grad->Slice(i, i + 1).Resize(output_matrix_shape); filter_shape_vec.erase(filter_shape_vec.begin(),
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); filter_shape_vec.begin() + 2);
for (int g = 0; g < groups; g++) {
// im2col
Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
}
};
template <typename Place, typename T> // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
class GemmConv3DKernel : public framework::OpKernel<T> { std::vector<int64_t> output_shape_vec(
public: framework::vectorize(output_grad->dims()));
void Compute(const framework::ExecutionContext& context) const override { output_shape_vec.erase(output_shape_vec.begin(),
const Tensor* input = context.Input<Tensor>("Input"); output_shape_vec.begin() + 2);
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); // use col_shape in the im2col calculation
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
int groups = context.Attr<int>("groups"); // o_h, o_w}
std::vector<int64_t> col_shape_vec;
col_shape_vec.push_back(input->dims()[1] / groups);
col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
filter_shape_vec.end());
col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
output_shape_vec.end());
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
int filter_depth = filter.dims()[filter.dims().size() - 3];
int filter_height = filter.dims()[filter.dims().size() - 2];
int filter_width = filter.dims()[filter.dims().size() - 1];
int output_channels = output->dims()[1];
int output_depth = output->dims()[2];
int output_height = output->dims()[3];
int output_width = output->dims()[4];
math::Vol2ColFunctor<Place, T> vol2col;
// use col_shape in the vol2col calculation
framework::DDim col_shape = {input_channels / groups,
filter_depth,
filter_height,
filter_width,
output_depth,
output_height,
output_width};
// use col_matrix_shape in the gemm calculation // use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = { // size: (i_c/g * k_h * k_w, o_h * o_w)
input_channels / groups * filter_depth * filter_height * filter_width, // or
output_depth * output_height * output_width}; // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
Tensor col; framework::DDim col_matrix_shape =
col.mutable_data<T>(col_shape, context.GetPlace()); framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape framework::DDim input_shape = framework::slice_ddim(
// to call the matrix multiplication interface. input->dims(), 1, static_cast<int>(input->dims().size()));
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = { framework::DDim filter_matrix_shape = {filter.dims()[0],
input->dims()[1], input->dims()[2], input->dims()[3], filter.numel() / filter.dims()[0]};
input->dims()[4]}; // channel, depth, height, width
framework::DDim filter_matrix_shape = {
filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = { framework::DDim output_matrix_shape = {
output_channels, output_depth * output_height * output_width}; output_grad->dims()[1],
output_grad->numel() /
// convolution operator: vol2col + gemm (output_grad->dims()[0] * output_grad->dims()[1])};
int in_step = input_channels / groups;
int out_step = output_channels / groups;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
// vol2col
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
vol2col(context.device_context(), in_slice, col, strides[0], strides[1],
strides[2], paddings[0], paddings[1], paddings[2]);
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
}
}
}
};
template <typename Place, typename T>
class GemmConvGrad3DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); // convolution backward input operator: gemm + col2im(or col2vol)
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); // convolution backward weight operator: im2col(or vol2col) + gemm
int groups = context.Attr<int>("groups"); int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
int batch_size = input->dims()[0];
int input_channels = input->dims()[1];
int filter_depth = filter.dims()[filter.dims().size() - 3];
int filter_height = filter.dims()[filter.dims().size() - 2];
int filter_width = filter.dims()[filter.dims().size() - 1];
int output_channels = output_grad->dims()[1];
int output_depth = output_grad->dims()[2];
int output_height = output_grad->dims()[3];
int output_width = output_grad->dims()[4];
math::Col2VolFunctor<Place, T> col2vol;
math::Vol2ColFunctor<Place, T> vol2col;
// use col_shape in the vol2col and col2vol calculation
framework::DDim col_shape = {input_channels / groups,
filter_depth,
filter_height,
filter_width,
output_depth,
output_height,
output_width};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {
input_channels / groups * filter_depth * filter_height * filter_width,
output_depth * output_height * output_width};
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col, // col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape // but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix = col; Tensor col_matrix;
col.mutable_data<T>(col_shape, context.GetPlace());
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {
input->dims()[1], input->dims()[2], input->dims()[3],
input->dims()[4]}; // channel, depth, height, width
framework::DDim output_matrix_shape = {output_grad->dims()[1],
output_grad->dims()[2] *
output_grad->dims()[3] *
output_grad->dims()[4]};
framework::DDim filter_matrix_shape = {
filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape);
// convolution backward input operator: gemm + col2vol
// convolution backward weight operator: vol2col + gemm
int in_step = input_channels / groups;
int out_step = output_channels / groups;
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
if (input_grad) { if (input_grad) {
...@@ -421,13 +257,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -421,13 +257,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
math::matmul<Place, T>(context.device_context(), filter_slice, true, math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix, out_grad_slice, false, T(1.0), &col_matrix,
T(0.0)); T(0.0));
// col2im
// col2vol
Tensor in_grad_slice = Tensor in_grad_slice =
in_grad_batch.Slice(g * in_step, (g + 1) * in_step); in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
col2vol(context.device_context(), in_grad_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1], if (filter_shape_vec.size() == 2) {
paddings[2]); math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Col2VolFunctor<Place, T> col2vol;
col2vol(context.device_context(), in_grad_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
}
} }
} }
} }
...@@ -443,13 +288,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -443,13 +288,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
output_grad->Slice(i, i + 1).Resize(output_matrix_shape); output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
// vol2col // im2col
Tensor out_grad_slice = Tensor out_grad_slice =
out_grad_batch.Slice(g * out_step, (g + 1) * out_step); out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
vol2col(context.device_context(), in_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1], if (filter_shape_vec.size() == 2) {
paddings[2]); math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, strides[0],
strides[1], strides[2], paddings[0], paddings[1],
paddings[2]);
}
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
...@@ -462,6 +316,5 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -462,6 +316,5 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -61,25 +61,23 @@ class TestConv2dOp(OpTest): ...@@ -61,25 +61,23 @@ class TestConv2dOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
# self.groups = 1
# self.op_type = "conv2d"
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.dilations = [1, 1] self.dilations = [1, 1]
......
...@@ -64,20 +64,20 @@ class TestConv3dOp(OpTest): ...@@ -64,20 +64,20 @@ class TestConv3dOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.03,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.03,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册