/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Base convolution operator definitions, for other conv-like
// operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int padding,
                      int stride) {
  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
  return output_size;
}
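// For example, a 32x32 input with a 3x3 filter, padding 1 and stride 1
// gives (32 - 3 + 2 * 1) / 1 + 1 = 32, i.e. a "same"-size output.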

// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Conv2DOpMaker(framework::OpProto* proto,
                framework::OpAttrChecker* op_checker);
};

class Conv2DOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override;
};

class Conv2DOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override;
};

template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped in the calculations,
    // so here we take a copy by assignment rather than a pointer,
    // which avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    int groups = context.Attr<int>("groups");

    int batch_size = input->dims()[0];
    int input_channels = input->dims()[1];
    int filter_height = filter.dims()[filter.dims().size() - 2];
    int filter_width = filter.dims()[filter.dims().size() - 1];
    int output_channels = output->dims()[1];
    int output_height = output->dims()[2];
    int output_width = output->dims()[3];

    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        im2col;
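    // kCFO lays each column buffer out as [input_channels, filter_height,
    // filter_width, output_height, output_width]; col_shape below follows
    // this layout (per group).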
    // use col_shape in the im2col calculation
    framework::DDim col_shape = {input_channels / groups, filter_height,
                                 filter_width, output_height, output_width};
    // use col_matrix_shape in the gemm calculation
    framework::DDim col_matrix_shape = {
        input_channels / groups * filter_height * filter_width,
        output_height * output_width};
    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix = col;
    col_matrix.Resize(col_matrix_shape);

    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                   input->dims()[3]};
    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {output_channels,
                                           output_height * output_width};
    // convolution operator: im2col + gemm
    int in_step = input_channels / groups;
    int out_step = output_channels / groups;
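    // Each group independently maps in_step (= input_channels / groups)
    // input channels to out_step (= output_channels / groups) output
    // channels; groups == 1 is an ordinary dense convolution.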
    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
      Tensor out_batch = output->Slice<T>(i, i + 1).Resize(output_matrix_shape);
      for (int g = 0; g < groups; g++) {
        // im2col
        Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
        im2col(context.device_context(), in_slice, col, strides[0], strides[1],
               paddings[0], paddings[1]);

        // gemm
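        // filter_slice: [out_step, in_step * filter_height * filter_width]
        // col_matrix:   [in_step * filter_height * filter_width,
        //                output_height * output_width]
        // out_slice:    [out_step, output_height * output_width]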
        Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
        Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
        math::matmul<Place, T>(context.device_context(), filter_slice, false,
                               col_matrix, false, T(1.0), &out_slice, T(0.0));
      }
    }
  }
};
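// A minimal registration sketch (hypothetical; the real registration lives
// in the accompanying conv2d_op.cc / conv2d_op.cu):
//
//   REGISTER_OP_CPU_KERNEL(
//       conv2d,
//       paddle::operators::GemmConv2DKernel<paddle::platform::CPUPlace,
//                                           float>);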

template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

    // The filter and filter_grad will be reshaped in the calculations,
    // so here we take a copy of the filter by assignment rather than a
    // pointer, which avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    int groups = context.Attr<int>("groups");

    int batch_size = input->dims()[0];
    int input_channels = input->dims()[1];
    int filter_height = filter.dims()[filter.dims().size() - 2];
    int filter_width = filter.dims()[filter.dims().size() - 1];
    int output_channels = output_grad->dims()[1];
    int output_height = output_grad->dims()[2];
    int output_width = output_grad->dims()[3];

    paddle::operators::math::Col2ImFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        col2im;
    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        im2col;
    // use col_shape in the im2col and col2im calculation
    framework::DDim col_shape = {input_channels / groups, filter_height,
                                 filter_width, output_height, output_width};
    // use col_matrix_shape in the gemm calculation
    framework::DDim col_matrix_shape = {
        input_channels / groups * filter_height * filter_width,
        output_height * output_width};
    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix = col;
    col_matrix.Resize(col_matrix_shape);

    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                   input->dims()[3]};
    framework::DDim output_matrix_shape = {
        output_grad->dims()[1],
        output_grad->dims()[2] * output_grad->dims()[3]};

    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    // convolution backward input operator:  gemm + col2im
    // convolution backward weight operator: im2col + gemm
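    // In matrix form, the forward pass per group is Y = W * Col, so
    //   dCol = W^T * dY   (col2im then scatters dCol back into dInput),
    //   dW  += dY * Col^T (accumulated over the batch).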
    int in_step = input_channels / groups;
    int out_step = output_channels / groups;

    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      auto t = framework::EigenVector<T>::Flatten(*input_grad);
      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));

      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
        Tensor in_grad_batch =
            input_grad->Slice<T>(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
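          // The `true` flag transposes filter_slice, so this computes
          // filter_slice^T * out_grad_slice into col_matrix.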
          Tensor out_grad_slice =
              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
          Tensor filter_slice =
              filter.Slice<T>(g * out_step, (g + 1) * out_step);
          math::matmul<Place, T>(context.device_context(), filter_slice, true,
                                 out_grad_slice, false, T(1.0), &col_matrix,
                                 T(0.0));

          // col2im
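          // col2im scatters the columns back into image layout; entries
          // from overlapping receptive fields are summed.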
          Tensor in_grad_slice =
              in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
          col2im(context.device_context(), in_grad_slice, col, strides[0],
                 strides[1], paddings[0], paddings[1]);
        }
      }
    }

    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      Tensor filter_grad_ = *filter_grad;
      filter_grad_.Resize(filter_matrix_shape);
      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));

      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // im2col
          Tensor out_grad_slice =
              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
          Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
          im2col(context.device_context(), in_slice, col, strides[0],
                 strides[1], paddings[0], paddings[1]);

          // gemm
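          // beta is T(1.0) in the matmul below, so each batch's product is
          // accumulated into filter_grad_slice instead of overwriting it.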
          Tensor filter_grad_slice =
              filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
          math::matmul<Place, T>(context.device_context(), out_grad_slice,
                                 false, col_matrix, true, T(1.0),
                                 &filter_grad_slice, T(1.0));
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle