conv_op.h 14.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

H
hedaoyuan 已提交
17
#include "paddle/framework/eigen.h"
18 19 20
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
C
chengduoZH 已提交
21
#include "paddle/operators/math/vol2col.h"
22 23 24 25 26 27

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

武毅 已提交
28 29
// Base convolution operator definitions, shared so that other conv-like
// operators can reuse the implementation.
C
chengduoZH 已提交
30
// Computes the output extent of a convolution along one spatial dimension:
//
//   dilated_kernel = dilation * (filter_size - 1) + 1
//   output_size    = (input_size + 2 * padding - dilated_kernel) / stride + 1
//
// No argument validation is performed. C++ integer division truncates toward
// zero, so when the padded input is smaller than the dilated kernel the
// numerator is negative and the result may be 0, negative, or even 1
// (e.g. input 3, filter 5, dilation 1, padding 0, stride 3 gives 1).
// NOTE(review): confirm callers reject such configurations.
inline int OutputSize(int input_size, int filter_size, int dilation,
                      int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
  return output_size;
}
C
chengduoZH 已提交
36 37 38
// Returns true when the convolution needs an explicit im2col / vol2col
// expansion of the input. It returns false only when, for every inspected
// dimension, the filter extent is 1, the stride is 1, the padding is 0 and
// the dilation is 1; in that case the GEMM kernels below use the input slice
// directly as the column matrix.
//
// filter_dim holds only the spatial filter extents ({k_h, k_w} or
// {k_d, k_h, k_w}). strides.size() decides how many dimensions are checked,
// so filter_dim, paddings and dilations must have at least that many
// entries.
//
// The vectors are taken by const reference: they are never modified, and
// this also accepts const vectors and temporaries. Existing callers that pass
// non-const lvalues are unaffected.
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations) {
  for (size_t j = 0; j < strides.size(); ++j) {
    // Compare the extent as int64_t; narrowing it to int first could make a
    // very large extent alias to 1.
    const bool is_identity_dim = filter_dim[j] == 1 && strides[j] == 1 &&
                                 paddings[j] == 0 && dilations[j] == 1;
    if (!is_identity_dim) return true;
  }
  return false;
}
武毅 已提交
48 49 50 51 52 53 54 55 56

// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
// Proto/attribute maker for the 2-D convolution operator. Only the
// constructor is declared here; it is defined out of line (presumably in the
// corresponding .cc file — verify). It presumably declares the "Input",
// "Filter" and "Output" variables and the "groups", "strides", "paddings" and
// "dilations" attributes that the kernels in this header read — confirm
// against the definition.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Conv2DOpMaker(framework::OpProto* proto,
                framework::OpAttrChecker* op_checker);
};

C
chengduoZH 已提交
57 58 59 60 61 62 63
// Proto/attribute maker for the 3-D convolution operator. Only the
// constructor is declared here; it is defined out of line (presumably in the
// corresponding .cc file — verify). It presumably mirrors Conv2DOpMaker with
// 3-D filters ({k_d, k_h, k_w}), matching the vol2col path of the kernels in
// this header — confirm against the definition.
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Conv3DOpMaker(framework::OpProto* proto,
                framework::OpAttrChecker* op_checker);
};

class ConvOp : public framework::OperatorWithKernel {
武毅 已提交
64 65 66 67 68
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
};

C
chengduoZH 已提交
69
class ConvOpGrad : public framework::OperatorWithKernel {
武毅 已提交
70 71 72 73 74
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
};

75
template <typename Place, typename T>
C
chengduoZH 已提交
76
class GemmConvKernel : public framework::OpKernel<T> {
77 78 79
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
H
hedaoyuan 已提交
80 81 82 83
    // The filter will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
84 85 86
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

C
chengduoZH 已提交
87
    int groups = context.Attr<int>("groups");
88 89
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
90
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
91

C
chengduoZH 已提交
92 93 94 95 96 97 98 99 100 101 102
    const int batch_size = static_cast<int>(input->dims()[0]);

    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
    filter_shape_vec.erase(filter_shape_vec.begin(),
                           filter_shape_vec.begin() + 2);

    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
    output_shape_vec.erase(output_shape_vec.begin(),
                           output_shape_vec.begin() + 2);
103

H
hedaoyuan 已提交
104
    // use col_shape in the im2col calculation
C
chengduoZH 已提交
105 106 107 108 109 110 111 112 113 114
    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
    // o_h, o_w}
    std::vector<int64_t> col_shape_vec;
    col_shape_vec.push_back(input->dims()[1] / groups);
    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
                         filter_shape_vec.end());
    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
                         output_shape_vec.end());
    framework::DDim col_shape(framework::make_ddim(col_shape_vec));

H
hedaoyuan 已提交
115
    // use col_matrix_shape in the gemm calculation
C
chengduoZH 已提交
116 117 118 119 120
    // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
    // o_h * o_w)
    framework::DDim col_matrix_shape =
        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);

C
chengduoZH 已提交
121
    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
H
hedaoyuan 已提交
122
    Tensor col;
H
hedaoyuan 已提交
123 124 125
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
C
chengduoZH 已提交
126
    Tensor col_matrix;
C
chengduoZH 已提交
127
    if (is_expand) {
C
chengduoZH 已提交
128 129 130 131
      col.mutable_data<T>(col_shape, context.GetPlace());
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }
132

C
chengduoZH 已提交
133 134 135
    framework::DDim input_shape = framework::slice_ddim(
        input->dims(), 1, static_cast<int>(input->dims().size()));

H
hedaoyuan 已提交
136 137
    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
H
hedaoyuan 已提交
138 139
    filter.Resize(filter_matrix_shape);

C
chengduoZH 已提交
140 141 142 143 144 145 146 147
    framework::DDim output_matrix_shape = {
        output->dims()[1],
        output->numel() / (output->dims()[0] * output->dims()[1])};

    // convolution operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int out_step = static_cast<int>(output->dims()[1]) / groups;

C
chengduoZH 已提交
148 149
    math::Vol2ColFunctor<Place, T> vol2col;
    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
C
chengduoZH 已提交
150

C
chengduoZH 已提交
151 152 153
    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
C
chengduoZH 已提交
154

C
chengduoZH 已提交
155 156
      for (int g = 0; g < groups; g++) {
        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
H
hedaoyuan 已提交
157

C
chengduoZH 已提交
158
        if (!is_expand) {
C
chengduoZH 已提交
159 160 161
          col.ShareDataWith(in_slice);
          col_matrix.ShareDataWith(col);
          col_matrix.Resize(col_matrix_shape);
C
chengduoZH 已提交
162 163 164 165 166 167 168 169 170 171
        } else if (filter_shape_vec.size() == 2) {
          // im2col
          im2col(context.device_context(), in_slice, dilations, strides,
                 std::vector<int>{paddings[0], paddings[1], paddings[0],
                                  paddings[1]},
                 &col);
        } else if (filter_shape_vec.size() == 3) {
          // vol2col
          vol2col(context.device_context(), in_slice, dilations, strides,
                  paddings, &col);
C
chengduoZH 已提交
172
        }
C
chengduoZH 已提交
173 174 175 176 177 178

        // gemm
        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
        math::matmul<Place, T>(context.device_context(), filter_slice, false,
                               col_matrix, false, T(1.0), &out_slice, T(0.0));
H
hedaoyuan 已提交
179
      }
180 181 182 183 184
    }
  }
};

template <typename Place, typename T>
C
chengduoZH 已提交
185
class GemmConvGradKernel : public framework::OpKernel<T> {
186 187
 public:
  void Compute(const framework::ExecutionContext& context) const override {
H
hedaoyuan 已提交
188 189 190 191 192
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
H
hedaoyuan 已提交
193
    Tensor* filter_grad =
H
hedaoyuan 已提交
194
        context.Output<Tensor>(framework::GradVarName("Filter"));
H
hedaoyuan 已提交
195 196 197 198
    // The filter and filter_grad will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
H
hedaoyuan 已提交
199

C
chengduoZH 已提交
200 201
    if (!input_grad && !filter_grad) return;

C
chengduoZH 已提交
202
    int groups = context.Attr<int>("groups");
H
hedaoyuan 已提交
203 204
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
205
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
H
hedaoyuan 已提交
206

C
chengduoZH 已提交
207
    const int batch_size = static_cast<int>(input->dims()[0]);
H
hedaoyuan 已提交
208

C
chengduoZH 已提交
209 210 211 212
    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
    filter_shape_vec.erase(filter_shape_vec.begin(),
                           filter_shape_vec.begin() + 2);
213

C
chengduoZH 已提交
214 215 216 217 218
    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(
        framework::vectorize(output_grad->dims()));
    output_shape_vec.erase(output_shape_vec.begin(),
                           output_shape_vec.begin() + 2);
C
chengduoZH 已提交
219

C
chengduoZH 已提交
220 221 222 223 224 225 226 227 228 229
    // use col_shape in the im2col calculation
    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
    // o_h, o_w}
    std::vector<int64_t> col_shape_vec;
    col_shape_vec.push_back(input->dims()[1] / groups);
    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
                         filter_shape_vec.end());
    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
                         output_shape_vec.end());
    framework::DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
230 231

    // use col_matrix_shape in the gemm calculation
C
chengduoZH 已提交
232 233 234 235 236 237 238 239
    // size: (i_c/g * k_h * k_w, o_h * o_w)
    // or
    // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
    framework::DDim col_matrix_shape =
        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);

    framework::DDim input_shape = framework::slice_ddim(
        input->dims(), 1, static_cast<int>(input->dims().size()));
C
chengduoZH 已提交
240

C
chengduoZH 已提交
241 242
    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
C
chengduoZH 已提交
243 244 245
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {
C
chengduoZH 已提交
246 247 248
        output_grad->dims()[1],
        output_grad->numel() /
            (output_grad->dims()[0] * output_grad->dims()[1])};
C
chengduoZH 已提交
249

C
chengduoZH 已提交
250 251 252 253
    // convolution backward input operator:  gemm + col2im(or col2vol)
    // convolution backward weight operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
C
chengduoZH 已提交
254

C
chengduoZH 已提交
255
    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
C
chengduoZH 已提交
256 257 258 259
    Tensor col;
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
C
chengduoZH 已提交
260
    Tensor col_matrix;
C
chengduoZH 已提交
261
    if (is_expand) {
C
chengduoZH 已提交
262 263 264 265
      col.mutable_data<T>(col_shape, context.GetPlace());
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }
C
chengduoZH 已提交
266

C
chengduoZH 已提交
267
    math::SetConstant<Place, T> set_zero;
C
chengduoZH 已提交
268 269 270

    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
C
chengduoZH 已提交
271
      set_zero(context.device_context(), input_grad, static_cast<T>(0));
C
chengduoZH 已提交
272

C
chengduoZH 已提交
273 274
      math::Col2VolFunctor<Place, T> col2vol;
      math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
C
chengduoZH 已提交
275

C
chengduoZH 已提交
276 277 278 279 280 281 282 283 284 285 286 287 288 289
      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);

          Tensor in_grad_slice =
              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);

          if (!is_expand) {
C
chengduoZH 已提交
290 291
            col_matrix.ShareDataWith(in_grad_slice);
            col_matrix.Resize(col_matrix_shape);
C
chengduoZH 已提交
292 293 294 295 296 297 298 299 300 301 302 303 304
          }
          math::matmul<Place, T>(context.device_context(), filter_slice, true,
                                 out_grad_slice, false, T(1.0), &col_matrix,
                                 T(0.0));

          if (is_expand && filter_shape_vec.size() == 2) {
            col2im(context.device_context(), col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &in_grad_slice);
          } else if (is_expand && filter_shape_vec.size() == 3) {
            col2vol(context.device_context(), col, dilations, strides, paddings,
                    &in_grad_slice);
C
chengduoZH 已提交
305
          }
C
chengduoZH 已提交
306 307 308 309 310 311 312 313
        }
      }
    }

    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      Tensor filter_grad_ = *filter_grad;
      filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
314
      set_zero(context.device_context(), filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
315 316 317 318 319 320 321 322 323 324 325
      math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
      math::Vol2ColFunctor<Place, T> vol2col;
      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // im2col
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
C
chengduoZH 已提交
326

C
chengduoZH 已提交
327
          if (!is_expand) {
C
chengduoZH 已提交
328 329 330
            col.ShareDataWith(in_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
C
chengduoZH 已提交
331 332 333 334 335 336 337 338
          } else if (filter_shape_vec.size() == 2) {
            im2col(context.device_context(), in_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &col);
          } else if (filter_shape_vec.size() == 3) {
            vol2col(context.device_context(), in_slice, dilations, strides,
                    paddings, &col);
C
chengduoZH 已提交
339
          }
C
chengduoZH 已提交
340 341 342 343 344 345 346

          // gemm
          Tensor filter_grad_slice =
              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
          math::matmul<Place, T>(context.device_context(), out_grad_slice,
                                 false, col_matrix, true, T(1.0),
                                 &filter_grad_slice, T(1.0));
C
chengduoZH 已提交
347 348 349 350 351
        }
      }
    }
  }
};
352 353
}  // namespace operators
}  // namespace paddle