/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
constexpr int kConvMKLDNNFP32 = 1;
constexpr int kConvMKLDNNINT8 = 2;

// Base convolution operator definitions for other conv-like
// operators to reuse the implementation.
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
  PADDLE_ENFORCE(
      output_size > 0,
      "Due to the settings of padding(%d), filter_size(%d), dilation(%d) and "
      "stride(%d), the output size is less than or equal to 0, please check "
      "again. Input_size:%d",
      padding, filter_size, dilation, stride, input_size);

  return output_size;
}
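
// For example (a quick check of the formula above): input_size = 5,
// filter_size = 3, dilation = 1, padding = 1 and stride = 1 give
// dkernel = 1 * (3 - 1) + 1 = 3, so
// output_size = (5 + 2 * 1 - 3) / 1 + 1 = 5, i.e. a "same" convolution.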
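
// IsExpand reports whether the im2col/vol2col expansion is actually needed.
// Only a 1x1 (or 1x1x1) kernel with unit stride, zero padding and unit
// dilation leaves each input slice unchanged; in that case the kernels below
// let the column buffer alias the input slice instead of materializing it.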
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final;

 protected:
  virtual void Apply() {}
};

class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final;

 protected:
  virtual void Apply() {}
};

class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{
        {"Input", /*->*/ "Output"}};
  }
};

class ConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class ConvOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

template <typename DeviceContext, typename T>
class GemmConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped in the calculations,
    // so here we use an assignment operation
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    int groups = context.Attr<int>("groups");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

    const int batch_size = static_cast<int>(input->dims()[0]);

    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));

    // use col_shape in the im2col calculation
    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
    // o_h, o_w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = input->dims()[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
    }
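    // For example (illustrative only): a 3x3 convolution over an input with
    // i_c = 6 and groups = 2 gives col_shape_vec = {3, 3, 3, o_h, o_w}.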
    framework::DDim col_shape(framework::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
    // o_h * o_w)
    framework::DDim col_matrix_shape =
        framework::flatten_to_2d(col_shape, data_dim + 1);

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
    Tensor col;
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    if (is_expand) {
      col.mutable_data<T>(col_shape, context.GetPlace());
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    framework::DDim input_shape = framework::slice_ddim(
        input->dims(), 1, static_cast<int>(input->dims().size()));

    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {
        output->dims()[1],
        output->numel() / (output->dims()[0] * output->dims()[1])};

    // convolution operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int out_step = static_cast<int>(output->dims()[1]) / groups;
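
    // Per-group GEMM shapes (a sketch of what blas.MatMul sees below):
    //   filter_slice: (o_c/g) x (i_c/g * k_h * k_w)
    //   col_matrix:   (i_c/g * k_h * k_w) x (o_h * o_w)
    //   out_slice:    (o_c/g) x (o_h * o_w) = filter_slice * col_matrix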

    math::Vol2ColFunctor<DeviceContext, T> vol2col;
    math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);

      for (int g = 0; g < groups; g++) {
        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

        if (!is_expand) {
          col.ShareDataWith(in_slice);
          col_matrix.ShareDataWith(col);
          col_matrix.Resize(col_matrix_shape);
        } else if (data_dim == 2U) {
          // im2col
          im2col(dev_ctx, in_slice, dilations, strides,
                 std::vector<int>{paddings[0], paddings[1], paddings[0],
                                  paddings[1]},
                 &col);
        } else if (data_dim == 3U) {
          // vol2col
          vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
        }

        // gemm
        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
        blas.MatMul(filter_slice, false, col_matrix, false, T(1.0), &out_slice,
                    T(0.0));
      }
    }
  }
};

template <typename DeviceContext, typename T>
class GemmConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    // The filter and filter_grad will be reshaped in the calculations,
    // so here we use an assignment operation
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    int groups = context.Attr<int>("groups");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

    const int batch_size = static_cast<int>(input->dims()[0]);

    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(
        framework::vectorize(output_grad->dims()));

    // use col_shape in the im2col calculation
    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
    // o_h, o_w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = input->dims()[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
    }
    framework::DDim col_shape(framework::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (i_c/g * k_h * k_w, o_h * o_w)
    // or
    // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
    framework::DDim col_matrix_shape =
        framework::flatten_to_2d(col_shape, data_dim + 1);

    framework::DDim input_shape = framework::slice_ddim(
        input->dims(), 1, static_cast<int>(input->dims().size()));

    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {
        output_grad->dims()[1],
        output_grad->numel() /
            (output_grad->dims()[0] * output_grad->dims()[1])};

    // convolution backward input operator:  gemm + col2im(or col2vol)
    // convolution backward weight operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int out_step = static_cast<int>(output_grad->dims()[1]) / groups;
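
    // In GEMM terms (a sketch of the two MatMul calls below): per group,
    //   d(col_matrix) = filter_slice^T * out_grad_slice, which col2im/col2vol
    //   then scatters back into d(input);
    //   d(filter_slice) += out_grad_slice * col_matrix^T, accumulated over
    //   the batch (hence beta = T(1.0) in the second MatMul).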

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
    Tensor col;
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    if (is_expand) {
      col.mutable_data<T>(col_shape, context.GetPlace());
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());

      // If is_expand is false, set_zero is unnecessary, because blas.MatMul
      // (with beta = 0) overwrites input_grad directly.
      if (is_expand) {
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
      }
      math::Col2VolFunctor<DeviceContext, T> col2vol;
      math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;

      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);

          Tensor in_grad_slice =
              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);

          if (!is_expand) {
            col_matrix.ShareDataWith(in_grad_slice);
            col_matrix.Resize(col_matrix_shape);
          }
          blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
                      &col_matrix, T(0.0));

          if (is_expand && data_dim == 2U) {
            col2im(dev_ctx, col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &in_grad_slice);
          } else if (is_expand && data_dim == 3U) {
            col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice);
          }
        }
      }
    }

    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      Tensor filter_grad_ = *filter_grad;
      filter_grad_.Resize(filter_matrix_shape);
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // im2col
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

          if (!is_expand) {
            col.ShareDataWith(in_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
          } else if (data_dim == 2U) {
            im2col(dev_ctx, in_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &col);
          } else if (data_dim == 3U) {
            vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
          }

          // gemm
          Tensor filter_grad_slice =
              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
                      &filter_grad_slice, T(1.0));
        }
      }
    }
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    PADDLE_ENFORCE_EQ(
        output->dims()[1] % input->dims()[1], 0,
        "The output channels must be a multiple of the input channels");
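    // Depthwise convolution: each input channel is convolved with its own
    // set of filters; the check above permits a channel multiplier, i.e.
    // output channels = multiplier * input channels.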
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

    math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
                  output);
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();

    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
        depthwiseConvFilterGrad;

    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, input_grad, static_cast<T>(0));
      depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
                             paddings, dilations, input_grad);
    }

    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
      depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings,
                              dilations, filter_grad);
    }
  }
};
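
// A registration sketch for reference (the actual registration lives in the
// corresponding conv_op.cc / conv_op.cu files and may differ):
//   REGISTER_OP_CPU_KERNEL(
//       conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);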

}  // namespace operators
}  // namespace paddle