conv_mkldnn_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <unordered_map>
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;

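// Rewrites the logical filter dims for a grouped convolution: an OIHW (or
// OIDHW) filter becomes GOIHW (GOIDHW), i.e. the output-channel dim is split
// into {groups, output_channels / groups}.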
inline void GetWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                         int groups, bool is_conv3d) {
  if (groups > 1) {
    if (is_conv3d) {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int dimension = weights_tz[2];
      int height = weights_tz[3];
      int width = weights_tz[4];
      weights_tz.resize(6);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = dimension;
      weights_tz[4] = height;
      weights_tz[5] = width;
    } else {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int height = weights_tz[2];
      int width = weights_tz[3];
      weights_tz.resize(5);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = height;
      weights_tz[4] = width;
    }
  }
}

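// For grouped filters the weights memory format gains a leading groups dim
// (goihw / goidhw); with a single group the caller-provided format is used
// unchanged.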
inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
                                           int groups, bool is_conv3d) {
  if (is_conv3d) {
    return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
  } else {
    return (groups == 1) ? format : MKLDNNMemoryFormat::goihw;
  }
}

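// Chooses the MKLDNN data type of the convolution output. For INT8 kernels a
// ReLU-like fused activation guarantees a non-negative result, so u8 is used
// instead of s8; force_fp32_output and a fused residual input of a different
// type both override that choice.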
static mkldnn::memory::data_type GetDstType(bool is_int8,
                                            bool force_fp32_output,
                                            std::string fuse_activation,
                                            bool fuse_residual_conn,
                                            const Tensor* residual_param) {
  auto dst_dt = mkldnn::memory::data_type::f32;  // uint8_t, int8_t, float
  if (is_int8) {
    dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
                 ? mkldnn::memory::data_type::u8
                 : mkldnn::memory::data_type::s8;
    if (force_fp32_output) {
      dst_dt = mkldnn::memory::data_type::f32;
    }
    if (fuse_residual_conn && residual_param) {
      auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
      if (dst_dt != residual_dt) dst_dt = residual_dt;
    }
  }
  return dst_dt;
}

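// T is the input element type (float, uint8_t or int8_t); K is the filter
// element type (float in all of the kernel registrations at the bottom of
// this file).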
template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    bool is_INT8 =
        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
    if (!is_INT8) {
      ComputeFP32(ctx);
    } else {
      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
                               fuse_residual_conn, residual_param);
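      // Dispatch on the output data type chosen above; the INT8 kernel is
      // instantiated once per output element type.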
      if (dst_dt == mkldnn::memory::data_type::f32) {
        ComputeINT8<float>(ctx);
      } else if (dst_dt == mkldnn::memory::data_type::u8) {
        ComputeINT8<uint8_t>(ctx);
      } else if (dst_dt == mkldnn::memory::data_type::s8) {
        ComputeINT8<int8_t>(ctx);
      }
    }
  }

  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
    const bool is_test = ctx.Attr<bool>("is_test");

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "The input tensor's layout should be %d, but got %d.",
                          DataLayout::kMKLDNN, input->layout()));
    PADDLE_ENFORCE_NE(
        input->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for Input tensor"));

    PADDLE_ENFORCE_EQ(
        filter->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The Filter tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN, filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Filter tensor");

    PADDLE_ENFORCE_GE(
        input->dims().size(), 4,
        "Input must have 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE_LE(
        input->dims().size(), 5,
        "Input must have 4 or 5 dimensions, i.e. NCHW or NCDHW");

    PADDLE_ENFORCE_GE(
        filter->dims().size(), 4,
        "Filter must have 4 or 5 dimensions, i.e. OIHW or OIDHW");
    PADDLE_ENFORCE_LE(
        filter->dims().size(), 5,
        "Filter must have 4 or 5 dimensions, i.e. OIHW or OIDHW");

    if (bias) {
      PADDLE_ENFORCE_EQ(
          bias->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The Bias tensor's layout should be %d, but got %d.",
              DataLayout::kMKLDNN, bias->layout()));
      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
                        "Wrong format set for Bias tensor");

      PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
                        "Bias must only have 1 dimension, i.e. X");
    }

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));

    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    float fuse_beta = ctx.Attr<float>("fuse_beta");
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
    bool is_conv3d = strides.size() == 3U;

    auto input_dims = input->dims();
    auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
    auto filter_dims = filter->dims();
    auto filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());

    auto ksize = framework::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             data_dims, strides, ksize);

    std::vector<primitive> pipeline;

    PADDLE_ENFORCE(
        is_conv3d
            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
                  dilations[2] == 1
            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
        "dilation in convolution is not implemented yet");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();

    auto src_tz = paddle::framework::vectorize(input->dims());
    auto weights_tz = paddle::framework::vectorize(filter->dims());
    int g = std::max(groups, 1);

    GetWeightsTz(weights_tz, g, is_conv3d);

    auto dst_tz = paddle::framework::vectorize(output->dims());

    // Get a unique name for storing MKLDNN primitives
    const std::string key = platform::CreateKey(
        src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));

    auto src_format = input->format();
    MKLDNNMemoryFormat weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);

    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    // TODO(jczaja): This is a workaround to make grad op UT's numerical
    // gradient computation proper, as this op is called directly without a
    // fetch op following it, so the numerical gradient is computed (in
    // Python) using block formats, which will give wrong results
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        is_test ? MKLDNNMemoryFormat::any
                : platform::data_format_to_memory_format(data_format);

    weights_format = MKLDNNMemoryFormat::any;
    // Check the format for user's special output
    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
    }

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    std::vector<int64_t> bias_tz;
    auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

    // create a conv primitive descriptor and save it for usage in backward
    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                 : mkldnn::prop_kind::forward_training;
    if (bias) {
      bias_tz = paddle::framework::vectorize(bias->dims());
      auto bias_md = platform::MKLDNNMemDesc(
          bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
          fwd_prop_kind);
    } else {
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, boost::none, dst_md, strides, paddings,
          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
          fuse_residual_conn, fwd_prop_kind);
    }

    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));

    // create reorder primitive if the input format is not the preferred one
    auto src_memory_p =
        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
        user_weights_memory_p, pipeline, is_test);

    std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;

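    // With conv+elementwise_add fusion the convolution accumulates into the
    // residual tensor, so the output either shares the residual buffer
    // directly or has the residual data reordered into it first.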
    if (fuse_residual_conn) {
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto residual_param_data = residual_param->data<T>();

      PADDLE_ENFORCE_NE(
          residual_param_data, nullptr,
          "Provide data if you want MKLDNN conv+elementwise_add fusion");
      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                        "Output and elementwise parameter need to have the "
                        "same dimension sizes");

      if (residual_param->format() != handler.GetDstFormat()) {
        auto output_data =
            output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
        auto residual_data_tz =
            paddle::framework::vectorize(residual_param->dims());
        auto residual_data_type =
            paddle::framework::ToMKLDNNDataType(residual_param->type());

        auto user_residual_md = platform::MKLDNNMemDesc(
            residual_data_tz, residual_data_type, residual_param->format());
        user_residual_memory_p = handler.AcquireResidualDataMemory(
            user_residual_md, to_void_cast<T>(residual_param_data));

        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
      } else {
        // Changing ShareDataWith to TensorCopy results in performance drop
        // on ResNet architectures
        // (https://github.com/PaddlePaddle/Paddle/issues/22964)
        output->ShareDataWith(*residual_param);
        auto output_data = output->mutable_data<T>(ctx.GetPlace());
        dst_memory_p =
            handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
      }
    } else {
      auto output_data =
          output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
      dst_memory_p =
          handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
    }

    auto conv_p = handler.AcquireConvolution();

    mkldnn::stream astream(mkldnn_engine);
    if (bias) {
      const T* bias_data = bias->data<T>();
      auto user_bias_md = platform::MKLDNNMemDesc(
          {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
      auto user_bias_memory_p =
          handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));

      auto bias_memory_p =
          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);

      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                {MKLDNN_ARG_BIAS, *bias_memory_p},
                                {MKLDNN_ARG_DST, *dst_memory_p}});

    } else {
      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                {MKLDNN_ARG_DST, *dst_memory_p}});
    }
    astream.wait();

    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  }
  template <typename T_out>
  void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
    const bool is_test = ctx.Attr<bool>("is_test");

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("Input");
    auto* output = ctx.Output<Tensor>("Output");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "The input tensor's layout should be %d, but got %d.",
                          DataLayout::kMKLDNN, input->layout()));
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_GE(
        input->dims().size(), 4,
        "Input must have 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE_LE(
        input->dims().size(), 5,
        "Input must have 4 or 5 dimensions, i.e. NCHW or NCDHW");

    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    bool unsigned_output =
        (fuse_activation == "relu" || fuse_activation == "relu6");

    const T* input_data = input->data<T>();

    auto src_tz = paddle::framework::vectorize(input->dims());

    mkldnn::memory::data_type src_dt =
        paddle::framework::ToMKLDNNDataType(input->type());

    std::string key = platform::CreateKey(
        src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter"));

    const std::string key_conv_pd = key + "@conv_pd";
    bool need_s8_to_u8 = false;
    std::shared_ptr<mkldnn::convolution_forward> conv_p;
    std::shared_ptr<mkldnn::memory> src_memory_p;
    std::shared_ptr<mkldnn::memory> user_src_memory_p;
    std::shared_ptr<mkldnn::memory> dst_memory_p;
    std::vector<primitive> pipeline;
    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
    std::shared_ptr<platform::ConvMKLDNNHandler> handler;

    // This is a workaround for the hacky implementation
    // of conv int8 mkl-dnn. Once conv fp32 and conv int8
    // are merged/unified, this will disappear
    std::string key_tid = "";
    if (platform::get_cur_mkldnn_session_id() ==
        platform::kMKLDNNSessionID_Default) {
      key_tid = "-t:" + platform::ThreadIDasStr();
    }

    auto prim_key = key + key_tid + "@conv_p";
    auto dst_key = key + key_tid + "@dst_mem_p";
    auto src_key = key + key_tid + "@src_mem_p";
    auto weights_key = key + key_tid + "@weights_mem_p";
    auto bias_key = key + key_tid + "@bias_mem_p";
    auto user_src_key = key + key_tid + "@user_src_mem_p";
    auto user_residual_key = key + key_tid + "@user_residual_data_mem_p";
    auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
    auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";

    conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
        dev_ctx.GetBlob(prim_key));

    mkldnn::stream astream(mkldnn_engine);

    if (conv_p == nullptr || !is_test) {
      float fuse_alpha = ctx.Attr<float>("fuse_alpha");
      float fuse_beta = ctx.Attr<float>("fuse_beta");
      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");

      auto* filter = ctx.Input<Tensor>("Filter");

      PADDLE_ENFORCE_EQ(
          filter->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The filter tensor's layout should be %d, but got %d.",
              DataLayout::kMKLDNN, filter->layout()));
      PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                        "Wrong format set for Filter tensor");

      PADDLE_ENFORCE_GE(
          filter->dims().size(), 4,
          "Filter must have 4 or 5 dimensions, i.e. OIHW or OIDHW");
      PADDLE_ENFORCE_LE(
          filter->dims().size(), 5,
          "Filter must have 4 or 5 dimensions, i.e. OIHW or OIDHW");

      PADDLE_ENFORCE_EQ(
          !fuse_residual_conn || !force_fp32_output, true,
          "residual fusion does not support force output with fp32");

      auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;

      if (bias) {
        PADDLE_ENFORCE_EQ(
            bias->layout(), DataLayout::kMKLDNN,
            platform::errors::InvalidArgument(
                "The bias tensor's layout should be %d, but got %d.",
                DataLayout::kMKLDNN, bias->layout()));
        PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
                          "Wrong format set for Bias tensor");

        PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
                          "Bias must only have 1 dimension, i.e. X");
      }

      std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
      std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

      std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
      std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

      std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
      std::vector<int64_t> dilations(begin(dilations_temp),
                                     end(dilations_temp));

      std::string padding_algorithm =
          ctx.Attr<std::string>("padding_algorithm");

      bool is_conv3d = strides.size() == 3U;

      PADDLE_ENFORCE_NE(is_conv3d, true,
                        "int8 does not support conv3d currently");

      auto input_dims = input->dims();
      auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
      auto filter_dims = filter->dims();
      auto filter_data_dims =
          framework::slice_ddim(filter_dims, 2, filter_dims.size());

      auto ksize = framework::vectorize(filter_data_dims);

      UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                               data_dims, strides, ksize);

      int groups = ctx.Attr<int>("groups");
      auto weights_tz = paddle::framework::vectorize(filter->dims());
      int g = std::max(groups, 1);

      GetWeightsTz(weights_tz, g, is_conv3d);
      auto dst_tz = paddle::framework::vectorize(output->dims());

      PADDLE_ENFORCE_EQ(
          is_conv3d
              ? dilations.size() == 3 && dilations[0] == 1 &&
                    dilations[1] == 1 && dilations[2] == 1
              : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
          true, "dilation in convolution is not implemented yet");

      const K* filter_data = filter->data<K>();
      auto scale_in_data = ctx.Attr<float>("Scale_in");
      auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
      auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
      auto scale_out_data =
          force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
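      // The fused residual input arrives quantized with Scale_in_eltwise;
      // sum_scale rescales it to the output scale before accumulation.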
      float sum_scale =
          fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;

      bool is_multi_channel = scale_weights_data.size() > 1;

      int count = is_multi_channel ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0]
                                            : (weights_tz)[0])
                                   : 1;
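      // Requantization scale applied to the s32 accumulator:
      //   output_shift_scale[i] = scale_out / (scale_in * scale_weights[i])
      // e.g. scale_in = 0.5, scale_weights[i] = 2.0 and scale_out = 4.0 give
      // 4.0 / (0.5 * 2.0) = 4.0.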
      std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 1)
      for (int i = 0; i < count; i++) {
        if (scale_weights_data[i] == 0.0)
          output_shift_scale[i] =
              scale_out_data;  // weights data may contain zeros
                               // in some models, so the weights
                               // scale cannot be calculated
        else
          output_shift_scale[i] =
              static_cast<float>(static_cast<double>(scale_out_data) /
                                 (static_cast<double>(scale_in_data) *
                                  static_cast<double>(scale_weights_data[i])));
      }

      auto user_src_md =
          platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
      auto user_weights_md = platform::MKLDNNMemDesc(
          {weights_tz}, platform::MKLDNNGetDataType<K>(),
          ((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);

      /* create memory descriptor for convolution without specified format
       * ('any') which lets a primitive (convolution in this case) choose
       * the memory format preferred for best performance
       */
      auto chosen_memory_format = MKLDNNMemoryFormat::any;

      std::vector<int64_t> bias_tz;

      auto src_md =
          platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
      auto weights_md = platform::MKLDNNMemDesc(
          weights_tz, memory::data_type::s8, chosen_memory_format);
      auto dst_md = platform::MKLDNNMemDesc(
          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

      handler.reset(
          new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
      // create a conv primitive descriptor and save it for usage in backward
      auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
                                 : mkldnn::prop_kind::forward_training;

      if (bias) {
        bias_tz = paddle::framework::vectorize(bias->dims());
        auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
                                               MKLDNNMemoryFormat::x);
        conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
            src_md, weights_md, bias_md, dst_md, strides, paddings,
            mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
            fuse_residual_conn, propagation, output_shift_scale, sum_scale);
      } else {
        conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
            src_md, weights_md, boost::none, dst_md, strides, paddings,
            mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
            fuse_residual_conn, propagation, output_shift_scale, sum_scale);
      }

      // create mkldnn memory from input tensors (data/weights)
      user_src_memory_p =
          handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
      auto user_weights_memory_p = handler->AcquireWeightsMemory(
          user_weights_md, to_void_cast<K>(filter_data));

      // create reorder primitive if the input format is not the preferred one
      src_memory_p =
          handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);

      std::shared_ptr<mkldnn::memory> weights_memory_p;
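      // Reorder scale mask: each set bit marks a weights dim carrying its own
      // scales, so 1 << 0 means per-output-channel scales and
      // (1 << 1) + (1 << 0) covers both the group and output-channel dims of
      // a goihw filter; 0 means a single common scale.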
      int mask_reorder =
          is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
      weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
          user_weights_memory_p, pipeline, is_test, true, scale_weights_data,
          mask_reorder);

      if (fuse_residual_conn) {
        auto residual_param = ctx.Input<Tensor>("ResidualData");
        PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                          "Output and elementwise parameter need to have the "
                          "same dimension sizes");
        auto residual_dt =
            paddle::framework::ToMKLDNNDataType(residual_param->type());
        if (residual_param->format() != handler->GetDstFormat()) {
          auto residual_data_tz =
              paddle::framework::vectorize(residual_param->dims());
          auto user_residual_md = platform::MKLDNNMemDesc(
              residual_data_tz, residual_dt, residual_param->format());
          dst_memory_p = platform::SetDstMemory<T_out>(
              ctx, output, residual_param, user_residual_md, handler,
              &pipeline);
        } else {
          output->ShareDataWith(*residual_param);
          dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
        }
        need_s8_to_u8 =
            (platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
            unsigned_output;
      } else {
        dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
      }

      // create convolution op primitive
      auto scale_bias_key = key + "@scale_bias";
      conv_p = handler->AcquireConvolution();
      if (bias) {
        const K* bias_data = bias->data<K>();
        auto user_bias_md = platform::MKLDNNMemDesc(
            {bias_tz}, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::x);
        auto user_bias_memory_p = handler->AcquireBiasMemory(
            user_bias_md, to_void_cast<K>(bias_data));
        std::shared_ptr<mkldnn::memory> bias_memory_p;
        int mask_reorder = is_multi_channel ? 1 << 0 : 1;
        int count =
            is_multi_channel
                ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
                : 1;
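        // Bias is passed to MKL-DNN as s32 values quantized with
        // scale_in * scale_weights[i], matching the scale of the s32
        // accumulator.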
        std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 1)
        for (int i = 0; i < count; i++) {
          scale_bias_data[i] = scale_in_data * scale_weights_data[i];
        }
        bias_memory_p = handler->AcquireBiasMemoryFromPrimitive(
            user_bias_memory_p, pipeline, is_test, true, scale_bias_data,
            mask_reorder);
        conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                  {MKLDNN_ARG_BIAS, *bias_memory_p},
                                  {MKLDNN_ARG_DST, *dst_memory_p}});
      } else {
        conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                  {MKLDNN_ARG_DST, *dst_memory_p}});
      }
    } else {
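      // Primitives were cached by a previous run under this key: rewire the
      // cached memories to the current tensors' data and re-execute.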
      auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
          dev_ctx.GetBlob(src_reorder_key));
      src_memory_p =
          std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
      if (src_memory_reorder_p) {
        user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(
            dev_ctx.GetBlob(user_src_key));
        user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
        src_memory_reorder_p->execute(astream, *user_src_memory_p,
                                      *src_memory_p);
        astream.wait();
      } else if (src_memory_p) {
        src_memory_p->set_data_handle(to_void_cast<T>(input_data));
      }
      auto weights_memory_p = std::static_pointer_cast<mkldnn::memory>(
          dev_ctx.GetBlob(weights_key));
      dst_memory_p =
          std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
      conv_pd =
          std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
              dev_ctx.GetBlob(key_conv_pd));
      if (conv_pd) {
        handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
                                                      mkldnn_engine, key));
      }

      if (fuse_residual_conn) {
        auto residual_param = ctx.Input<Tensor>("ResidualData");
        output->ShareDataWith(*residual_param);
        need_s8_to_u8 =
            (platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
            unsigned_output;
      }
      platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p);

      auto residual_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
          dev_ctx.GetBlob(residual_reorder_key));
      if (residual_reorder_p) {
        auto user_residual_data_p = std::static_pointer_cast<mkldnn::memory>(
            dev_ctx.GetBlob(user_residual_key));
        residual_reorder_p->execute(astream, *user_residual_data_p,
                                    *dst_memory_p);
        astream.wait();
      }

      auto bias_memory_p =
          std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(bias_key));

      if (bias_memory_p) {
        conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                  {MKLDNN_ARG_BIAS, *bias_memory_p},
                                  {MKLDNN_ARG_DST, *dst_memory_p}});
      } else {
        conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                  {MKLDNN_ARG_DST, *dst_memory_p}});
      }
    }
    astream.wait();
    if (need_s8_to_u8) {
      output->mutable_data<uint8_t>(ctx.GetPlace());
    }
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  }
};

template <typename T>
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("Input");
    const Tensor* filter = ctx.Input<Tensor>("Filter");
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "The input tensor's layout should be %d, but got %d.",
                          DataLayout::kMKLDNN, input->layout()));
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(
        filter->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The filter tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN, filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for Filter tensor");

    PADDLE_ENFORCE_EQ(
        output_grad->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The output_grad tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN, output_grad->layout()));
    PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::undef,
                      "Wrong format set for output_grad tensor");

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
787 788
        "is_test attribute should be set to False in training phase.");

    if (!input_grad && !filter_grad) return;

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));

    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    int groups = ctx.Attr<int>("groups");

    bool is_conv3d = strides.size() == 3U;
    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data = nullptr;
    T* filter_grad_data = nullptr;

    auto input_dims = input->dims();
    auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
    auto filter_dims = filter->dims();
    auto filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());

    auto ksize = framework::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             data_dims, strides, ksize);

    auto src_tz = paddle::framework::vectorize(input->dims());
    auto weights_tz = paddle::framework::vectorize(filter->dims());

    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g, is_conv3d);
    auto dst_tz = paddle::framework::vectorize(output_grad->dims());

    auto src_format = input->format();
    MKLDNNMemoryFormat weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);

    // Get a unique name from the "argument" names of the "Input" and "Filter"
    // variables, as well as the attributes of the primitive to be created.
    // This name will be used as the key when saving info into the device
    // context.
    const std::string key = platform::CreateKey(
        src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));

    const std::string key_conv_pd = key + "@conv_pd";
    std::vector<primitive> pipeline;

    // Create user memory descriptors
    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
    auto user_diff_dst_md = platform::MKLDNNMemDesc(
        {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());

    /* create memory descriptor for conv backward without specified format
     * ('any') which lets a primitive (conv backward in this case) choose
     * the memory format preferred for best performance
     */

    // TODO(jczaja): Once GRAD NHWC is working then format 'any'
    // should be used exclusively. But till forward pass enforce
    // NCHW for training we need to have NCHW here as well
    // to avoid performance degradation in relu_grad and pool2d_grad
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    weights_format = MKLDNNMemoryFormat::any;
    // Check the format for user's special output
    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
    }

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto diff_src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    // Retrieve conv_pd from device context
    auto conv_pd =
        std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
            dev_ctx.GetBlob(key_conv_pd));
    PADDLE_ENFORCE_NE(conv_pd, nullptr,
                      "Fail to find conv_pd in device context");

    auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

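    // Both backward descriptors below are created with the cached forward
    // conv_pd as a hint, so that MKL-DNN picks gradient memory formats
    // matching the forward primitive.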
    // create backward convolution weights primitive descriptor
    auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
        mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
        diff_dst_md, strides, mkldnn_paddings[0], mkldnn_paddings[1]);

    auto conv_bwd_weights_pd =
        std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
            conv_bwd_weights_desc, mkldnn_engine, *conv_pd);

    // create backward convolution data primitive descriptor
    auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
        mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
        diff_dst_md, strides, mkldnn_paddings[0], mkldnn_paddings[1]);

    auto conv_bwd_data_pd =
        std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
            conv_bwd_data_desc, mkldnn_engine, *conv_pd);

    platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd,
                                        conv_bwd_weights_pd, dev_ctx,
                                        mkldnn_engine, key);

    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));
    auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
        user_diff_dst_md, to_void_cast<T>(output_grad_data));
    mkldnn::stream astream(mkldnn_engine);
    if (filter_grad) {
      auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
          user_src_memory_p, pipeline);

      auto diff_dst_memory_4filter_p =
          handler.AcquireDiffDstMemoryFromWeightsPrimitive(
              user_diff_dst_memory_p, pipeline);

      const size_t size = handler.GetDiffWeightsMemorySize();
      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);

      auto diff_weights_memory_p =
          handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
              reinterpret_cast<void*>(filter_grad_data));

      auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights();

      // TODO(grygielski) why no bias_diff?
      conv_bwd_weights_p->execute(
          astream, {{MKLDNN_ARG_SRC, *src_memory_p},
                    {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4filter_p},
                    {MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
      astream.wait();

      filter_grad->set_layout(DataLayout::kMKLDNN);
      filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
    }
    if (input_grad) {
      auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
          user_weights_memory_p, pipeline);

      auto diff_dst_memory_4data_p =
          handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
                                                        pipeline);

      const size_t size = handler.GetDiffSourceMemorySize();
      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);

      auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
          reinterpret_cast<void*>(input_grad_data));

      auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData();

      conv_bwd_data_p->execute(astream,
                               {{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
                                {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4data_p},
                                {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
      astream.wait();

      input_grad->set_layout(DataLayout::kMKLDNN);
      input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, U8,
988
                                    ops::kConvMKLDNNINT8,
989
                                    ops::ConvMKLDNNOpKernel<uint8_t, float>);
990 991 992

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, S8,
993
                                    ops::kConvMKLDNNINT8,
994
                                    ops::ConvMKLDNNOpKernel<int8_t, float>);
X
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
1004
                                    ops::ConvMKLDNNOpKernel<float, float>);
1005 1006 1007 1008 1009

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);