conv_mkldnn_op.cc 46.9 KB
Newer Older
A
Adam Osewski 已提交
1
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

A
Adam Osewski 已提交
15 16
#include <tuple>

17
#include "paddle/fluid/framework/expect.h"
18
#include "paddle/fluid/operators/conv_op.h"
J
Jacek Czaja 已提交
19
#include "paddle/fluid/platform/cpu_info.h"
A
Adam Osewski 已提交
20
#include "paddle/fluid/platform/mkldnn_helper.h"
J
Jacek Czaja 已提交
21
#include "paddle/fluid/platform/mkldnn_reuse.h"
22 23 24

namespace paddle {
namespace operators {
A
Adam Osewski 已提交
25
namespace {
26

27 28 29
// Returns the memory format used to describe user-provided convolution
// weights.
//
// Non-grouped weights (groups == 1) keep the tensor's own `format`. Any other
// `groups` value uses an explicit leading group dimension: goidhw for 3D
// convolution, goihw otherwise.
inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
                                           const int groups,
                                           const bool is_conv3d) {
  if (groups == 1) {
    return format;
  }
  return is_conv3d ? MKLDNNMemoryFormat::goidhw : MKLDNNMemoryFormat::goihw;
}

37 38 39 40 41 42
// Selects the oneDNN data type of the convolution output (dst).
//
// INT8 kernels (is_int8):
//   * u8 when the fused activation is "relu" or "relu6", s8 otherwise;
//   * f32 when force_fp32_output is set;
//   * when a residual connection is fused and residual_param is non-null, the
//     residual tensor's data type overrides the choices above.
// Other kernels:
//   * f32 by default;
//   * bf16 when is_bfloat16 is set and force_fp32_output is not, in which case
//     a fused residual (non-null residual_param) again overrides the type with
//     the residual tensor's data type.
static dnnl::memory::data_type GetDstType(bool is_int8, bool is_bfloat16,
                                          bool force_fp32_output,
                                          const std::string& fuse_activation,
                                          bool fuse_residual_conn,
                                          const Tensor* residual_param) {
  const bool has_residual = fuse_residual_conn && residual_param != nullptr;
  auto dst_dt = dnnl::memory::data_type::f32;

  if (is_int8) {
    dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
                 ? dnnl::memory::data_type::u8
                 : dnnl::memory::data_type::s8;
    if (force_fp32_output) {
      dst_dt = dnnl::memory::data_type::f32;
    }
    if (has_residual) {
      dst_dt = framework::ToMKLDNNDataType(
          framework::TransToProtoVarType(residual_param->dtype()));
    }
  } else if (!force_fp32_output && is_bfloat16) {
    dst_dt = dnnl::memory::data_type::bf16;
    if (has_residual) {
      dst_dt = framework::ToMKLDNNDataType(
          framework::TransToProtoVarType(residual_param->dtype()));
    }
  }
  return dst_dt;
}

67
template <typename T, typename K, typename T_out>
68
class ConvMKLDNNHandlerT
69 70 71
    : public platform::MKLDNNHandlerT<T, dnnl::convolution_forward,
                                      dnnl::convolution_backward_data,
                                      dnnl::convolution_backward_weights> {
72
 public:
A
Adam Osewski 已提交
73
  ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx,
74
                     const platform::MKLDNNDeviceContext& dev_ctx,
75
                     const dnnl::engine mkldnn_engine,
76 77 78
                     platform::Place cpu_place, const Tensor* input,
                     const Tensor* filter, const Tensor* bias, Tensor* output,
                     const std::string& unique_name)
79 80 81
      : platform::MKLDNNHandlerT<T, dnnl::convolution_forward,
                                 dnnl::convolution_backward_data,
                                 dnnl::convolution_backward_weights>(
82
            dev_ctx, mkldnn_engine, cpu_place,
83
            platform::CreateKey(dev_ctx, phi::vectorize(input->dims()),
84
                                unique_name)) {
85
    if (unlikely(!this->isCached())) {
86
      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
87
          input->layout(), framework::DataLayout::kMKLDNN,
88 89
          platform::errors::InvalidArgument(
              "The input tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
90
              framework::DataLayout::kMKLDNN, input->layout()));
91 92 93
      PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Wrong format set for Input tensor"));
94

95
      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
96
          filter->layout(), framework::DataLayout::kMKLDNN,
97 98
          platform::errors::InvalidArgument(
              "The Filter tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
99
              framework::DataLayout::kMKLDNN, filter->layout()));
100 101 102
      PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Wrong format set for Filter tensor"));
K
Krzysztof Binias 已提交
103

104 105 106 107 108 109 110 111 112 113 114 115
      PADDLE_ENFORCE_GE(
          input->dims().size(), 4,
          platform::errors::InvalidArgument(
              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
              "NCDHW, but got dimension = %d .",
              input->dims().size()));
      PADDLE_ENFORCE_LE(
          input->dims().size(), 5,
          platform::errors::InvalidArgument(
              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
              "NCDHW, but got dimension = %d .",
              input->dims().size()));
116

117 118 119 120 121 122 123 124 125 126 127 128
      PADDLE_ENFORCE_GE(
          filter->dims().size(), 4,
          platform::errors::InvalidArgument(
              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
              "OIDHW, but got dimension = %d .",
              filter->dims().size()));
      PADDLE_ENFORCE_LE(
          filter->dims().size(), 5,
          platform::errors::InvalidArgument(
              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
              "OIDHW, but got dimension = %d .",
              filter->dims().size()));
129

130 131
      if (bias) {
        PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
132
            bias->layout(), framework::DataLayout::kMKLDNN,
133 134
            platform::errors::InvalidArgument(
                "The Bias tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
135
                framework::DataLayout::kMKLDNN, bias->layout()));
136 137 138
        PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
                          platform::errors::InvalidArgument(
                              "Got wrong format for Bias tensor."));
139

140 141 142 143 144 145
        PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
                          platform::errors::InvalidArgument(
                              "Bias must only have 1 dimension, "
                              "i.e. X, but got dimension = %d .",
                              bias->dims().size()));
      }
F
FDInSky 已提交
146

147 148 149 150 151 152 153 154 155
      const std::string fuse_activation =
          ctx.Attr<std::string>("fuse_activation");
      const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
      const float fuse_beta = ctx.Attr<float>("fuse_beta");
      const bool fuse_residual_conn =
          ctx.Attr<bool>("fuse_residual_connection");
      const int groups = ctx.Attr<int>("groups");
      const std::string padding_algorithm =
          ctx.Attr<std::string>("padding_algorithm");
F
FDInSky 已提交
156

157
      const auto input_dims = input->dims();
158
      const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
159 160
      const auto filter_dims = filter->dims();
      const auto filter_data_dims =
161
          phi::slice_ddim(filter_dims, 2, filter_dims.size());
162

163
      const auto ksize = phi::vectorize(filter_data_dims);
164
      const bool is_test = ctx.Attr<bool>("is_test");
165

166 167
      auto strides_temp = ctx.Attr<std::vector<int>>("strides");
      std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
168

169 170
      auto paddings_temp = ctx.Attr<std::vector<int>>("paddings");
      std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
A
Adam 已提交
171

172 173 174
      auto dilations_temp = ctx.Attr<std::vector<int>>("dilations");
      std::vector<int64_t> dilations(begin(dilations_temp),
                                     end(dilations_temp));
A
Adam 已提交
175

176 177
      UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                               data_dims, strides, ksize);
A
Adam 已提交
178

179 180
      std::transform(dilations.begin(), dilations.end(), dilations.begin(),
                     [](int64_t i) { return i - 1; });
181

182
      const auto src_tz = phi::vectorize(input->dims());
183

184
      auto weights_tz = phi::vectorize(filter->dims());
185
      platform::GetGroupConvWeightsTz(weights_tz, groups);
186

187
      const auto dst_tz = phi::vectorize(output->dims());
188

189
      const dnnl::memory::dims stride_dims = strides;
190
      const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
191
      const dnnl::memory::dims dilations_dims = dilations;
A
Adam 已提交
192

193 194 195 196
      /* create memory descriptor for convolution without specified format
       * ('any') which lets a primitive (convolution in this case) choose
       * the memory format preferred for best performance
       */
197
      auto chosen_memory_format = MKLDNNMemoryFormat::any;
198
      auto data_type = dnnl::memory::data_type::f32;
199 200
      if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
          std::is_same<T_out, platform::bfloat16>::value)
201
        data_type = dnnl::memory::data_type::bf16;
202

203
      dnnl::memory::desc src_md, weights_md;
A
Adam Osewski 已提交
204 205
      if (platform::is_int8<T>()) {
        src_md = platform::MKLDNNMemDesc(
206 207
            src_tz, framework::ToMKLDNNDataType(
                        framework::TransToProtoVarType(input->dtype())),
A
Adam Osewski 已提交
208 209
            chosen_memory_format);
        weights_md = platform::MKLDNNMemDesc(
210
            weights_tz, dnnl::memory::data_type::s8, chosen_memory_format);
A
Adam Osewski 已提交
211 212 213 214 215 216 217
      } else {
        src_md =
            platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
        weights_md = platform::MKLDNNMemDesc(weights_tz, data_type,
                                             MKLDNNMemoryFormat::any);
      }

218
      const auto dst_md = platform::MKLDNNMemDesc(
219
          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
220 221
      const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
                                         : dnnl::prop_kind::forward_training;
222

J
jakpiase 已提交
223
      float sum_scale = 1.0f;
224
      float activation_scale = 1.0f;
A
Adam Osewski 已提交
225
      std::vector<float> output_shift_scale;
J
jakpiase 已提交
226
      if (platform::is_int8<T>())
227 228
        std::tie(sum_scale, output_shift_scale, activation_scale) =
            get_int8_scales(ctx);
A
Adam Osewski 已提交
229

230
      const dnnl::primitive_attr conv_attr = CreatePostOps(
A
Adam Osewski 已提交
231
          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
232
          output_shift_scale, sum_scale, activation_scale);  // for INT8 only!
A
Adam 已提交
233

234
      if (bias) {
235
        auto bias_tz = phi::vectorize(bias->dims());
236
        dnnl::memory::desc bias_md;
A
Adam Osewski 已提交
237 238
        if (platform::is_int8<T>()) {
          bias_md = platform::MKLDNNMemDesc(
239
              bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x);
A
Adam Osewski 已提交
240 241 242 243
        } else {
          bias_md = platform::MKLDNNMemDesc(bias_tz, data_type,
                                            MKLDNNMemoryFormat::x);
        }
244

245
        this->AcquireForwardPrimitiveDescriptor(
246
            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
247
            src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims,
248 249
            mkldnn_paddings[0], mkldnn_paddings[1]);
      } else {
250
        this->AcquireForwardPrimitiveDescriptor(
251
            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
252 253
            src_md, weights_md, dst_md, stride_dims, dilations_dims,
            mkldnn_paddings[0], mkldnn_paddings[1]);
254 255 256
      }
    }
  }
257

258 259 260 261 262 263
  ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx,
                     const platform::MKLDNNDeviceContext& dev_ctx,
                     platform::Place cpu_place, const Tensor* in,
                     const Tensor* filter, const Tensor* bias,
                     const Tensor* out_grad, Tensor* filter_grad,
                     Tensor* in_x_grad, const std::string& unique_name)
264 265 266
      : platform::MKLDNNHandlerT<T, dnnl::convolution_forward,
                                 dnnl::convolution_backward_data,
                                 dnnl::convolution_backward_weights>(
267
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
268
            platform::CreateKey(dev_ctx, phi::vectorize(in->dims()),
269
                                unique_name)) {
270
    if (unlikely(!this->isBwdCached())) {
271
      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
272
          in->layout(), framework::DataLayout::kMKLDNN,
273 274
          platform::errors::InvalidArgument(
              "The input tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
275
              framework::DataLayout::kMKLDNN, in->layout()));
276 277 278 279 280
      PADDLE_ENFORCE_NE(in->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Got wrong format for Input tensor."));

      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
281
          filter->layout(), framework::DataLayout::kMKLDNN,
282 283
          platform::errors::InvalidArgument(
              "The filter tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
284
              framework::DataLayout::kMKLDNN, filter->layout()));
285 286 287 288 289
      PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Got wrong format for Filter tensor."));

      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
290
          out_grad->layout(), framework::DataLayout::kMKLDNN,
291 292
          platform::errors::InvalidArgument(
              "The output_grad tensor's layout should be %d, but got %d.",
A
Adam Osewski 已提交
293
              framework::DataLayout::kMKLDNN, out_grad->layout()));
294 295 296 297 298
      PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Wrong format set for output_grad tensor"));

      PADDLE_ENFORCE_EQ(
299
          ctx.Attr<bool>("is_test"), false,
300 301 302 303 304 305 306 307 308 309 310 311 312 313
          platform::errors::InvalidArgument(
              "is_test attribute should be set to False in training phase."));

      std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
      std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

      std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
      std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

      std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
      std::vector<int64_t> dilations(begin(dilations_temp),
                                     end(dilations_temp));

      auto input_dims = in->dims();
314
      auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
315 316
      auto filter_dims = filter->dims();
      auto filter_data_dims =
317 318
          phi::slice_ddim(filter_dims, 2, filter_dims.size());
      auto ksize = phi::vectorize(filter_data_dims);
319

A
Adam Osewski 已提交
320 321
      std::string padding_algorithm =
          ctx.Attr<std::string>("padding_algorithm");
322 323 324
      UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                               data_dims, strides, ksize);

325 326
      auto src_tz = phi::vectorize(in->dims());
      auto weights_tz = phi::vectorize(filter->dims());
327

A
Adam Osewski 已提交
328
      int groups = ctx.Attr<int>("groups");
329 330
      int g = std::max(groups, 1);
      platform::GetGroupConvWeightsTz(weights_tz, g);
331
      auto dst_tz = phi::vectorize(out_grad->dims());
332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355

      /* create memory descriptor for conv backward without specified format
       * ('any') which lets a primitive (conv backward in this case) choose
       * the memory format preferred for best performance
       */
      const auto chosen_memory_format = MKLDNNMemoryFormat::any;
      const auto weights_format = MKLDNNMemoryFormat::any;

      auto src_md = platform::MKLDNNMemDesc(
          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
      const auto dst_md = platform::MKLDNNMemDesc(
          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
      auto diff_src_md = platform::MKLDNNMemDesc(
          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
      auto weights_md = platform::MKLDNNMemDesc(
          weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
      auto diff_weights_md = platform::MKLDNNMemDesc(
          weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
      auto diff_dst_md = platform::MKLDNNMemDesc(
          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

      auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
      std::transform(dilations.begin(), dilations.end(), dilations.begin(),
                     [](int64_t i) { return i - 1; });
356
      const dnnl::memory::dims dilations_dims = dilations;
357

358
      const dnnl::memory::dims stride_dims = strides;
359
      // Recreating FWD PD. For training there are no post ops in convolution
360
      dnnl::primitive_attr conv_attr;
361
      if (bias) {
362
        auto bias_tz = phi::vectorize(bias->dims());
363
        dnnl::memory::desc bias_md;
A
Adam Osewski 已提交
364 365
        if (platform::is_int8<T>()) {
          bias_md = platform::MKLDNNMemDesc(
366
              bias_tz, dnnl::memory::data_type::s32, MKLDNNMemoryFormat::x);
A
Adam Osewski 已提交
367 368
        } else {
          bias_md = platform::MKLDNNMemDesc(
369
              bias_tz, dnnl::memory::data_type::f32, MKLDNNMemoryFormat::x);
A
Adam Osewski 已提交
370
        }
371

372
        this->AcquireForwardPrimitiveDescriptor(
373
            conv_attr, dnnl::prop_kind::forward_training,
374 375 376 377
            dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md,
            dst_md, stride_dims, dilations_dims, mkldnn_paddings[0],
            mkldnn_paddings[1]);
      } else {
378
        this->AcquireForwardPrimitiveDescriptor(
379
            conv_attr, dnnl::prop_kind::forward_training,
380 381 382 383 384
            dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
            stride_dims, dilations_dims, mkldnn_paddings[0],
            mkldnn_paddings[1]);
      }

385
      this->AcquireBackwardPrimitiveDescriptor(
386
          dnnl::algorithm::convolution_direct, diff_src_md, weights_md,
387 388 389
          diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
          mkldnn_paddings[1]);

390
      this->AcquireBackwardWeightsPrimitiveDescriptor(
391
          dnnl::algorithm::convolution_direct, src_md, diff_weights_md,
392 393 394 395 396
          diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
          mkldnn_paddings[1]);
    }
  }

397 398 399 400 401 402 403 404 405 406 407 408 409
  std::shared_ptr<std::tuple<float, std::vector<float>>> get_int8_bias_scales(
      const framework::ExecutionContext& ctx) {
    // Get scales int8 bias key
    const std::string key_bs = this->key_ + "@bs";

    // Scales for int8 bias are to be cached to avoid
    // computing them each iteration
    auto bias_scale_tuple =
        std::static_pointer_cast<std::tuple<float, std::vector<float>>>(
            this->dev_ctx_.GetBlob(key_bs));
    if (bias_scale_tuple) return bias_scale_tuple;

    const auto* filter = ctx.Input<Tensor>("Filter");
410
    const auto& weights_tz = phi::vectorize(filter->dims());
411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439
    const int groups = std::max(ctx.Attr<int>("groups"), 1);

    const auto& scale_weights_data =
        ctx.Attr<std::vector<float>>("Scale_weights");
    const auto& scale_in_data = ctx.Attr<float>("Scale_in");

    bool is_multi_channel = scale_weights_data.size() > 1;
    int mask_reorder = is_multi_channel ? 1 << 0 : 1;

    int count = 1;
    if (is_multi_channel) {
      count *= weights_tz[0];
      if (groups > 1) {
        count *= weights_tz[1];
      }
    }

    bias_scale_tuple =
        std::make_shared<std::tuple<float, std::vector<float>>>(std::make_tuple(
            static_cast<float>(mask_reorder), std::vector<float>(count)));
    for (int i = 0; i < count; i++) {
      std::get<1>(*bias_scale_tuple)[i] = scale_in_data * scale_weights_data[i];
    }

    this->dev_ctx_.SetBlob(key_bs, bias_scale_tuple);

    return bias_scale_tuple;
  }

440
  std::tuple<float, std::vector<float>, float> get_int8_scales(
A
Adam Osewski 已提交
441 442
      const framework::ExecutionContext& ctx) const {
    const auto* filter = ctx.Input<Tensor>("Filter");
443
    const auto& weights_tz = phi::vectorize(filter->dims());
A
Adam Osewski 已提交
444 445 446 447 448 449 450 451 452

    const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    const int groups = std::max(ctx.Attr<int>("groups"), 1);

    const auto& scale_in_data = ctx.Attr<float>("Scale_in");
    const auto& scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
    auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
    bool is_multi_channel = scale_weights_data.size() > 1;
453 454 455 456
    bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
    float activation_scale =
        force_fp32_output ? 1.0f : has_activation ? ctx.Attr<float>("Scale_out")
                                                  : 1.0f;
A
Adam Osewski 已提交
457
    auto scale_out_data =
458 459 460
        force_fp32_output ? 1.0f : has_activation
                                       ? 1.0f
                                       : ctx.Attr<float>("Scale_out");
A
Adam Osewski 已提交
461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481
    float sum_scale =
        fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
    int count =
        is_multi_channel
            ? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
            : 1;
    std::vector<float> output_shift_scale(count);

#pragma omp parallel for if (count > 50)
    for (int i = 0; i < count; i++) {
      if (scale_weights_data[i] == 0.0)
        // weights data will contain 0 in some models, then weights
        // scale couldn't be calculated
        output_shift_scale[i] = scale_out_data;
      else
        output_shift_scale[i] =
            static_cast<float>(static_cast<double>(scale_out_data) /
                               (static_cast<double>(scale_in_data) *
                                static_cast<double>(scale_weights_data[i])));
    }

482
    return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
A
Adam Osewski 已提交
483 484
  }

485
  dnnl::primitive_attr CreatePostOps(
486 487
      std::string fuse_activation, float fuse_alpha, float fuse_beta,
      bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
488
      float sum_scale = 1.0f, float activation_scale = 1.0f) {
489 490
    dnnl::primitive_attr conv_attr;
    dnnl::post_ops post_operations;
491 492 493 494
    if (output_shift_scale.size() > 0) {
      int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
      conv_attr.set_output_scales(mask, output_shift_scale);
    }
495

496 497 498 499 500 501 502 503 504 505 506
    // Fusion with Elementwise layer relies on adding a sum post-operation with
    // the scale parameter. It is assumed that when fuse_residual_connection is
    // true, the output tensor contains the data coming from residual
    // connection. The result of this post_op is:
    // Output = scale * Output + Conv_Out.
    if (fuse_residual_conn) {
      post_operations.append_sum(sum_scale);
    }
    // Fusion with ReLU layer is executed through the PostOps feature. Create a
    // PostOps object and configure it to execute an eltwise relu operation.
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
507 508 509
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_relu, fuse_alpha,
                                     fuse_beta);
510
    } else if (fuse_activation == "relu6") {
511 512
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_bounded_relu,
513
                                     fuse_alpha, fuse_beta);
514 515 516 517
    } else if (fuse_activation == "swish") {
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_swish, fuse_alpha,
                                     fuse_beta);
J
jakpiase 已提交
518
    } else if (fuse_activation == "hard_swish") {
519 520
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_hardswish,
521
                                     fuse_alpha, fuse_beta);
522 523 524 525
    } else if (fuse_activation == "mish") {
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_mish, fuse_alpha,
                                     fuse_beta);
526
    } else if (fuse_activation == "hard_sigmoid") {
527 528
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_linear,
529
                                     fuse_alpha, fuse_beta);
530 531
      post_operations.append_eltwise(activation_scale,
                                     dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
B
baoachun 已提交
532
    } else if (fuse_activation == "gelu_tanh") {
533 534
      post_operations.append_eltwise(
          activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
B
baoachun 已提交
535
    } else if (fuse_activation == "gelu_erf") {
536 537
      post_operations.append_eltwise(
          activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
538 539 540 541
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }
542

543
  std::shared_ptr<dnnl::memory>
544 545 546
  AcquireWeightsMemoryWithReorderFromDataPrimitive(
      const framework::Tensor* filter, const int groups, const bool is_conv3d) {
    const K* filter_data = filter->data<K>();
547
    auto weights_tz = phi::vectorize(filter->dims());
548 549 550 551 552 553 554 555
    platform::GetGroupConvWeightsTz(weights_tz, groups);

    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<K>(),
        GetWeightsFormat(filter->format(), groups, is_conv3d));

    return this->AcquireMemoryWithReorder(
        user_src_md, this->bwd_pd_->weights_desc(),
A
Adam Osewski 已提交
556
        platform::to_void_cast<K>(filter_data), "@weights_mem_d_p", false);
557 558
  }

559
  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
560
      const framework::Tensor* input) {
561 562 563 564
    return this->AcquireMemoryWithReorderPrimitive(
        input, "@src_mem_p_user", "@src_mem_p_target", "@src_mem_p",
        this->fwd_pd_->src_desc());
  }
565

566
  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorderFromWeightsPrimitive(
567 568 569 570 571 572
      const framework::Tensor* input) {
    return this->AcquireMemoryWithReorderPrimitive(
        input, "@src_mem_w_p_user", "@src_mem_w_p_target", "@src_mem_w_p",
        this->bwd_w_pd_->src_desc());
  }

573
  std::shared_ptr<dnnl::memory>
574 575 576 577 578 579 580
  AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
      const framework::Tensor* out_grad) {
    return this->AcquireMemoryWithReorderPrimitive(
        out_grad, "@diff_dst_mem_w_p_user", "@diff_dst_mem_w_p_target",
        "@diff_dst_mem_w_p", this->bwd_w_pd_->diff_dst_desc());
  }

581
  std::shared_ptr<dnnl::memory>
582 583 584 585 586 587 588
  AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
      const framework::Tensor* out_grad) {
    return this->AcquireMemoryWithReorderPrimitive(
        out_grad, "@diff_dst_mem_p_user", "@diff_dst_mem_p_target",
        "@diff_dst_mem_p", this->bwd_pd_->diff_dst_desc());
  }

589
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorderPrimitive(
590 591
      const framework::Tensor* in_mem, const char* key_mem_user,
      const char* key_mem_target, const char* key_mem,
592
      const dnnl::memory::desc& mem_md) {
593 594 595 596 597 598
    const T* in_mem_data = in_mem->data<T>();
    const std::string user_key_suffix{key_mem_user};
    auto user_mem_p = this->AcquireMemory(user_key_suffix);

    if (!user_mem_p) {
      auto user_mem_md = platform::MKLDNNMemDesc(
599
          phi::vectorize(in_mem->dims()), platform::MKLDNNGetDataType<T>(),
600
          in_mem->format());
601
      return this->AcquireMemoryWithReorder(
602
          user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
603
    } else {
604 605
      const std::string target_key_suffix{key_mem_target};
      const auto target_mem_p = this->AcquireMemory(target_key_suffix);
A
Adam Osewski 已提交
606
      user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
607
      if (user_mem_p != target_mem_p) {
608
        this->AcquireReorder(user_mem_p, target_mem_p);
609
      }
610
      return target_mem_p;
611
    }
612 613
  }

614
  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
615
      const framework::Tensor* filter, const int groups, const bool is_conv3d,
616 617
      const bool is_test, const std::vector<float>& scale_data = {1.0f},
      int mask = 0) {
618 619 620
    // This is workaround to make execution faster, delete
    // if statement after including md inside Tensor
    auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
621
    if (is_test && weights_mem_p) {
622
      return weights_mem_p;
623
    } else if (is_test) {
624
      const K* filter_data = filter->data<K>();
625
      auto weights_tz = phi::vectorize(filter->dims());
626
      platform::GetGroupConvWeightsTz(weights_tz, groups);
627 628

      auto user_src_md = platform::MKLDNNMemDesc(
629
          weights_tz, platform::MKLDNNGetDataType<K>(),
630 631 632 633
          GetWeightsFormat(filter->format(), groups, is_conv3d));

      return this->AcquireMemoryWithReorder(
          user_src_md, this->fwd_pd_->weights_desc(),
634 635
          platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
          scale_data, mask);
636 637
    } else {
      const T* filter_data = filter->data<T>();
638
      auto weights_tz = phi::vectorize(filter->dims());
639 640 641 642 643 644 645 646 647 648
      platform::GetGroupConvWeightsTz(weights_tz, groups);

      auto user_src_md = platform::MKLDNNMemDesc(
          weights_tz, platform::MKLDNNGetDataType<T>(),
          GetWeightsFormat(filter->format(), groups, is_conv3d));

      return this->AcquireMemoryWithReorder(
          user_src_md, this->fwd_pd_->weights_desc(),
          platform::to_void_cast<T>(filter_data), "@weights_mem_p", is_test, {},
          scale_data, mask);
649
    }
650
  }
651

652
  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
653
      const framework::Tensor* bias, const bool is_test,
A
Adam Osewski 已提交
654
      const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
655
    auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
656
    if (is_test && bias_mem_p) {
657 658 659 660
      return bias_mem_p;
    } else {
      const K* bias_data = bias->data<K>();
      auto user_bias_md = platform::MKLDNNMemDesc(
661
          phi::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
662 663 664
          MKLDNNMemoryFormat::x);

      return this->AcquireMemoryWithReorder(
A
Adam Osewski 已提交
665
          user_bias_md, this->fwd_pd_->bias_desc(),
666
          platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {},
A
Adam Osewski 已提交
667
          scale_data, mask);
668
    }
669
  }
670

671
  std::shared_ptr<dnnl::memory> AcquireResidualMemory(
672
      const framework::Tensor* residual_param) {
673
    void* residual_data =
674 675
        framework::TransToProtoVarType(residual_param->dtype()) ==
                framework::DataTypeTrait<T_out>::DataType()
A
Adam Osewski 已提交
676 677
            ? platform::to_void_cast<T_out>(residual_param->data<T_out>())
            : platform::to_void_cast<T>(residual_param->data<T>());
678 679 680 681 682 683
    auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
    if (residual_mem_p) {
      residual_mem_p->set_data_handle(residual_data);
      return residual_mem_p;
    } else {
      auto user_residual_md = platform::MKLDNNMemDesc(
684
          phi::vectorize(residual_param->dims()),
685 686
          framework::ToMKLDNNDataType(
              framework::TransToProtoVarType(residual_param->dtype())),
687
          residual_param->format());
688

689 690 691
      return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
                                              "@user_residual_data_mem_p");
    }
692 693
  }

694
  std::shared_ptr<dnnl::memory> AcquireDstMemoryWithResidual(
695 696 697 698 699
      framework::Tensor* output, const framework::Tensor* residual_param) {
    std::shared_ptr<dnnl::memory> dst_memory_p;
    if (residual_param->format() !=
        platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
      auto residual_memory_p = this->AcquireResidualMemory(residual_param);
700
      dst_memory_p = this->template AcquireDstMemory<T_out>(output);
701
      this->AcquireReorder(residual_memory_p, dst_memory_p);
702 703 704 705 706
    } else {
      // Changing ShareDataWith to TensorCopy results in performance drop
      // on ResNet architectures
      // (https://github.com/PaddlePaddle/Paddle/issues/22964)
      output->ShareDataWith(*residual_param);
707
      dst_memory_p = this->template AcquireDstMemory<T_out>(output);
708 709 710 711 712
    }
    return dst_memory_p;
  }
};

A
Adam Osewski 已提交
713 714
}  // anonymous namespace

715
template <typename T, typename K>
A
Adam Osewski 已提交
716
class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
717
 public:
A
Adam Osewski 已提交
718
  void Compute(const framework::ExecutionContext& ctx) const override {
719
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
A
Adam Osewski 已提交
720
                      platform::errors::PreconditionNotMet(
721 722 723
                          "Operator DNNL Conv must use CPUPlace"));
    bool is_INT8 =
        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
724 725 726 727 728 729 730 731
    bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
    auto residual_param = ctx.Input<Tensor>("ResidualData");
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    auto dst_dt =
        GetDstType(is_INT8, is_BFLOAT16, force_fp32_output, fuse_activation,
                   fuse_residual_conn, residual_param);
732
    if (!is_INT8) {
733
      if (dst_dt == dnnl::memory::data_type::f32) {
734
        ComputeFP32<float>(ctx);
735
      } else if (dst_dt == dnnl::memory::data_type::bf16) {
736 737
        ComputeFP32<platform::bfloat16>(ctx);
      }
738
    } else {
739
      if (dst_dt == dnnl::memory::data_type::f32) {
740
        ComputeINT8<float>(ctx);
741
      } else if (dst_dt == dnnl::memory::data_type::u8) {
742
        ComputeINT8<uint8_t>(ctx);
743
      } else if (dst_dt == dnnl::memory::data_type::s8) {
744 745
        ComputeINT8<int8_t>(ctx);
      }
746
    }
747
  }
748

749
  template <typename T_out>
A
Adam Osewski 已提交
750
  void ComputeFP32(const framework::ExecutionContext& ctx) const {
751
    auto& dev_ctx =
A
Adam Osewski 已提交
752
        ctx.template device_context<platform::MKLDNNDeviceContext>();
753
    const auto& mkldnn_engine = dev_ctx.GetEngine();
754

755
    const bool is_test = ctx.Attr<bool>("is_test");
756 757
    const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
    const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
758

759 760 761 762 763
    const auto* input = ctx.Input<Tensor>("Input");
    const auto* filter = ctx.Input<Tensor>("Filter");
    const auto* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
764

765
    ConvMKLDNNHandlerT<T, K, T_out> handler(
766 767
        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
        output, ctx.InputName("Input") + ctx.InputName("Filter"));
768

769
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
770

771
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
772
        filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
773

774 775 776
    std::shared_ptr<dnnl::memory> dst_memory_p;
    if (fuse_residual_conn) {
      auto* residual_param = ctx.Input<Tensor>("ResidualData");
777
      dst_memory_p =
778 779
          handler.AcquireDstMemoryWithResidual(output, residual_param);
    } else {
780
      dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
781
    }
782

783
    auto conv_p = handler.AcquireForwardPrimitive();
A
Adam 已提交
784

785
    std::unordered_map<int, dnnl::memory> args = {
786 787 788
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};
A
Adam 已提交
789

790
    if (bias) {
791
      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
792
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
793
    }
794

795
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
796
    conv_p->execute(astream, args);
A
Adam 已提交
797
    astream.wait();
798

A
Adam Osewski 已提交
799 800
    output->set_layout(framework::DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
801
  }
802

803
  template <typename T_out>
A
Adam Osewski 已提交
804
  void ComputeINT8(const framework::ExecutionContext& ctx) const {
805
    auto& dev_ctx =
A
Adam Osewski 已提交
806
        ctx.template device_context<platform::MKLDNNDeviceContext>();
807 808
    const auto& mkldnn_engine = dev_ctx.GetEngine();

A
Adam Osewski 已提交
809 810 811 812 813
    const std::string& fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
814

815 816
    bool unsigned_output =
        (fuse_activation == "relu" || fuse_activation == "relu6");
817 818
    bool need_s8_to_u8 = false;

A
Adam Osewski 已提交
819 820 821 822 823 824 825 826
    PADDLE_ENFORCE_NE(
        is_conv3d, true,
        platform::errors::Unimplemented(
            "OneDNN int8 convolution does not support 3D inputs currently"));
    PADDLE_ENFORCE_EQ(
        fuse_residual_conn && force_fp32_output, false,
        platform::errors::Unimplemented(
            "residual fusion does not support force output with fp32"));
A
Adam 已提交
827

A
Adam Osewski 已提交
828 829 830 831
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
832

A
Adam Osewski 已提交
833 834 835
    ConvMKLDNNHandlerT<T, K, T_out> handler(
        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
        output, ctx.InputName("Input") + ctx.InputName("Filter"));
836

A
Adam Osewski 已提交
837
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
F
FDInSky 已提交
838

A
Adam Osewski 已提交
839 840 841 842 843 844 845
    const auto& scale_weights_data =
        ctx.Attr<std::vector<float>>("Scale_weights");
    const bool is_multi_channel = scale_weights_data.size() > 1;
    const int& groups = ctx.Attr<int>("groups");
    int mask_reorder =
        is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
846
        filter, groups, false, true, scale_weights_data, mask_reorder);
847

A
Adam Osewski 已提交
848 849 850
    std::shared_ptr<dnnl::memory> dst_memory_p;
    if (fuse_residual_conn) {
      auto* residual_param = ctx.Input<Tensor>("ResidualData");
851
      PADDLE_ENFORCE_EQ(
A
Adam Osewski 已提交
852 853 854 855 856 857
          output->dims(), residual_param->dims(),
          platform::errors::InvalidArgument(
              "Output and elementwise parameter need to have the "
              "same dimension sizes, but got output's dimension = %d"
              " and residual param's dimension =%d .",
              output->dims().size(), residual_param->dims().size()));
858
      dst_memory_p =
A
Adam Osewski 已提交
859 860
          handler.AcquireDstMemoryWithResidual(output, residual_param);
      need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
861
                       dnnl::memory::data_type::s8) &&
A
Adam Osewski 已提交
862 863 864 865
                      unsigned_output;
    } else {
      dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
    }
L
lidanqing 已提交
866

A
Adam Osewski 已提交
867 868 869
    auto conv_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args = {
870 871 872
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};
A
Adam 已提交
873

A
Adam Osewski 已提交
874
    if (bias) {
875
      auto p_scales_tuple = handler.get_int8_bias_scales(ctx);
A
Adam 已提交
876

A
Adam Osewski 已提交
877
      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
878
          bias, true, std::get<1>(*p_scales_tuple),
879
          std::get<0>(*p_scales_tuple));
880
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
881
    }
A
Adam Osewski 已提交
882 883 884

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    conv_p->execute(astream, args);
A
Adam 已提交
885
    astream.wait();
A
Adam Osewski 已提交
886

887
    if (need_s8_to_u8) {
X
xiaolil1 已提交
888 889
      output->mutable_data<uint8_t>(ctx.GetPlace());
    }
A
Adam Osewski 已提交
890 891 892

    output->set_layout(framework::DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
893
  }
894 895
};

896
template <typename T, typename K>
A
Adam Osewski 已提交
897
class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
898
 public:
A
Adam Osewski 已提交
899
  void Compute(const framework::ExecutionContext& ctx) const override {
900
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
A
Adam Osewski 已提交
901
                      platform::errors::PreconditionNotMet(
902
                          "Operator DNNL ConvGrad must use CPUPlace"));
903 904
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
905 906 907 908
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("Input");
    const Tensor* filter = ctx.Input<Tensor>("Filter");
909 910
    const Tensor* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
911 912 913 914 915 916 917
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));

    if (!input_grad && !filter_grad) return;

918 919 920 921 922
    // TODO(jczaja): Are all tensors really needed?
    ConvMKLDNNHandlerT<T, K, T> handler(
        ctx, dev_ctx, ctx.GetPlace(), input, filter, bias, output_grad,
        filter_grad, input_grad,
        ctx.InputName("Input") + ctx.InputName("Filter"));
923 924

    // create mkldnn memory from input tensors (data/weights)
925
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
926

927 928 929 930 931 932
    if (filter_grad) {
      auto src_memory_p =
          handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input);
      auto diff_dst_memory_p =
          handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
              output_grad);
933

934 935
      // For convoluition with groups write filter grad into
      // oneDNN buffer and then we reorder it into filter_grad tensor
936
      int g = std::max(ctx.Attr<int>("groups"), 1);
937
      auto diff_weights_memory_p =
938 939
          g > 1 ? handler.AcquireDiffWeightsMemory()
                : handler.AcquireDiffWeightsMemory(filter_grad);
940

941
      auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
942

A
Adam 已提交
943 944
      // TODO(grygielski) why no bias_diff?
      conv_bwd_weights_p->execute(
945 946 947
          astream, {{DNNL_ARG_SRC, *src_memory_p},
                    {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                    {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
A
Adam 已提交
948
      astream.wait();
949

A
Adam Osewski 已提交
950
      filter_grad->set_layout(framework::DataLayout::kMKLDNN);
951 952
      // in OneDNN groups in convolution are treated as separate dimension
      // which is not the case in paddlepaddle
A
Adam Osewski 已提交
953
      auto filter_fmt = platform::GetMKLDNNFormat(*diff_weights_memory_p);
954 955 956 957

      // For convolution with groups convert from blocked to NCHW
      // otherwise there will be problems in next operators working on this data
      if (g > 1) {
958 959
        dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
            framework::TransToProtoVarType(filter->dtype()));
960 961
        // for 3d conv with groups (six dimensional data reorder to goidhw)
        // for 2d conv with groups (five dimensional data reorder to goihw)
962
        // auto weights_tz = phi::vectorize(filter->dims());
963 964

        auto weights_tz = diff_weights_memory_p->get_desc().dims();
965 966 967
        dnnl::memory::format_tag out_format =
            weights_tz.size() == 6 ? dnnl::memory::format_tag::goidhw
                                   : dnnl::memory::format_tag::goihw;
968 969 970
        platform::ReorderMKLDNNHandler handler(
            weights_tz, framework::TransToProtoVarType(filter->dtype()),
            in_type, mkldnn_engine);
971 972 973 974 975 976
        auto reorder_dst_memory_p =
            handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());

        auto reorder_p =
            handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p);

977
        {
C
chenjian 已提交
978 979 980
          platform::RecordEvent record_reorder(
              "int_reorder", platform::TracerEventType::UserDefined, 2,
              platform::EventRole::kUniqueOp);
981 982 983 984
          reorder_p->execute(astream, *diff_weights_memory_p,
                             *reorder_dst_memory_p);
          astream.wait();
        }
985 986 987 988

        // So here we have a data in goihw , which can be interpreted as OIHW
        // (OIDHW for conv3d)
        // because filter_grad shape is set for OIHW (OIDHW for conv3d)
989 990 991
        dnnl::memory::format_tag target_format =
            weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
                                   : dnnl::memory::format_tag::oihw;
992 993 994 995
        filter_grad->set_format(target_format);
      } else {
        filter_grad->set_format(filter_fmt);
      }
996 997
    }
    if (input_grad) {
998 999 1000 1001
      auto weights_memory_p =
          handler.AcquireWeightsMemoryWithReorderFromDataPrimitive(
              filter, ctx.Attr<int>("groups"),
              ctx.Attr<std::vector<int>>("strides").size() == 3U);
1002

1003 1004 1005 1006
      auto diff_dst_memory_p =
          handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
              output_grad);
      auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad);
1007

1008
      auto conv_bwd_data_p = handler.AcquireBackwardPrimitive();
1009

A
Adam 已提交
1010
      conv_bwd_data_p->execute(astream,
1011 1012 1013
                               {{DNNL_ARG_WEIGHTS, *weights_memory_p},
                                {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
A
Adam 已提交
1014
      astream.wait();
1015

A
Adam Osewski 已提交
1016 1017
      input_grad->set_layout(framework::DataLayout::kMKLDNN);
      input_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
1018
    }
X
xiaolil1 已提交
1019
  }
1020
};
1021

1022 1023 1024 1025 1026
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Kernel registrations. Each call binds (op, MKLDNN library, CPUPlace,
// customized type name, kernel-type tag) to a kernel template instantiation.
// FP32 / BF16 kernels use the kConvMKLDNNFP32 tag; U8 / S8 kernels use
// kConvMKLDNNINT8.

// ---- conv2d (forward) ------------------------------------------------------
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d, MKLDNN, ::paddle::platform::CPUPlace, U8, ops::kConvMKLDNNINT8,
    ops::ConvMKLDNNOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d, MKLDNN, ::paddle::platform::CPUPlace, S8, ops::kConvMKLDNNINT8,
    ops::ConvMKLDNNOpKernel<int8_t, float>);

// ---- conv2d_grad -----------------------------------------------------------
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32,
    ops::kConvMKLDNNFP32, ops::ConvMKLDNNGradOpKernel<float, float>);
// NOTE(review): this BF16 grad kernel uses K = bfloat16, whereas the
// depthwise_conv2d_grad BF16 kernel below uses K = float. Confirm that the
// asymmetry is intended.
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
    ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16,
                                paddle::platform::bfloat16>);

// ---- depthwise_conv2d (forward) --------------------------------------------
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d, MKLDNN, ::paddle::platform::CPUPlace, FP32,
    ops::kConvMKLDNNFP32, ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d, MKLDNN, ::paddle::platform::CPUPlace, BF16,
    ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d, MKLDNN, ::paddle::platform::CPUPlace, U8,
    ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d, MKLDNN, ::paddle::platform::CPUPlace, S8,
    ops::kConvMKLDNNINT8, ops::ConvMKLDNNOpKernel<int8_t, float>);

// ---- depthwise_conv2d_grad -------------------------------------------------
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32,
    ops::kConvMKLDNNFP32, ops::ConvMKLDNNGradOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    depthwise_conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16,
    ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16, float>);

// ---- conv3d (FP32 only) ----------------------------------------------------
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv3d, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32,
    ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32,
    ops::kConvMKLDNNFP32, ops::ConvMKLDNNGradOpKernel<float, float>);