/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <unordered_map>

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;

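// Scale masks used for MKLDNN reorders/attributes: each set bit marks a
// weights dimension along which quantization scales vary. Per-tensor scales
// use mask 0; per-output-channel scales use bit 0 (oc) for [o, i, h, w]
// weights and bits 0|1 (g, oc) for grouped [g, o/g, i, h, w] weights.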
constexpr int same_scale_mask = 0;
constexpr int o_slice_mask = 1 << 0;                         // 1
constexpr int g_slice_mask = 1 << 1;                         // 2
constexpr int g_o_slice_mask = g_slice_mask | o_slice_mask;  // 3

static int ComputeMask(bool is_multi_channel, int multi_channel_mask) {
  return is_multi_channel ? multi_channel_mask : same_scale_mask;
}

static int ComputeWeightsMask(int is_multi_channel, int g) {
  int multi_channel_mask = g > 1 ? g_o_slice_mask : o_slice_mask;
  return ComputeMask(is_multi_channel, multi_channel_mask);
}

static int ComputeBiasMask(int is_multi_channel) {
  return ComputeMask(is_multi_channel, o_slice_mask);
}

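// Reshape the filter's logical dims from [o, i, (d,) h, w] to the grouped
// form [g, o/g, i, (d,) h, w] expected for grouped convolution.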
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups) {  // NOLINT
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w] -> [g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}

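// Pick the MKLDNN weights format tag: the caller-provided format when
// groups == 1, otherwise the grouped layout (goihw for 2D, goidhw for 3D).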
inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
                                           int groups, bool is_conv3d) {
  if (is_conv3d) {
    return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
  } else {
    return (groups == 1) ? format : MKLDNNMemoryFormat::goihw;
  }
}

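// Per-output-channel requantization factors (MKLDNN output scales):
// Scale_out / (Scale_in * Scale_weights[oc]); a zero weight scale falls back
// to Scale_out to avoid dividing by zero.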
static std::vector<float> ComputeOutputShiftScale(
    const float scale_out_data, const float scale_in_data,
    const std::vector<float>& scale_weights_data) {
  int count = scale_weights_data.size();
  std::vector<float> output_shift_scale(count);
#pragma omp parallel for
  for (int i = 0; i < count; i++) {
    if (scale_weights_data[i] == 0.0) {
      output_shift_scale[i] = scale_out_data;
    } else {
      output_shift_scale[i] =
          static_cast<float>(static_cast<double>(scale_out_data) /
                             (static_cast<double>(scale_in_data) *
                              static_cast<double>(scale_weights_data[i])));
    }
  }
  return output_shift_scale;
}

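// Bias is quantized with Scale_in * Scale_weights[oc] so that it matches the
// scale of the int32 convolution accumulator.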
static std::vector<float> ComputeBiasScale(
    const float scale_in_data, const std::vector<float>& scale_weights_data) {
  int count = scale_weights_data.size();
  std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 1)
  for (int i = 0; i < count; i++) {
    scale_bias_data[i] = scale_in_data * scale_weights_data[i];
  }
  return scale_bias_data;
}

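// Decide the output data type for the int8 path: follow the residual tensor
// type when the sum fusion is used, otherwise u8 after a fused (bounded) relu
// and s8 otherwise; f32 when int8 is disabled or fp32 output is forced.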
static mkldnn::memory::data_type GetDstType(bool is_int8,
                                            bool force_fp32_output,
                                            std::string fuse_activation,
                                            bool fuse_residual_conn,
                                            const Tensor* residual_param) {
  auto dst_dt = mkldnn::memory::data_type::f32;  // uint8_t, int8_t, float
  if (is_int8 && !force_fp32_output) {
    if (fuse_residual_conn && residual_param) {
      // when residual exists, dst_dt follows the residual_param type,
      // but output will be set to u8 if relu exists
      auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
      dst_dt = residual_dt;
    } else {
      // when residual does not exist, output is u8 if (b)relu is fused,
      // otherwise s8
      dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
                   ? mkldnn::memory::data_type::u8
                   : mkldnn::memory::data_type::s8;
    }
  }
  return dst_dt;
}

template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    bool is_INT8 =
        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
    if (!is_INT8) {
      ComputeFP32(ctx);
    } else {
      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
                               fuse_residual_conn, residual_param);
      if (dst_dt == mkldnn::memory::data_type::f32) {
        ComputeINT8<float>(ctx);
      } else if (dst_dt == mkldnn::memory::data_type::u8) {
        ComputeINT8<uint8_t>(ctx);
      } else if (dst_dt == mkldnn::memory::data_type::s8) {
        ComputeINT8<int8_t>(ctx);
      }
    }
  }

  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
    const bool is_test = ctx.Attr<bool>("is_test");

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Filter tensor");
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Filter tensor");

    PADDLE_ENFORCE_GE(
        input->dims().size(), 4,
        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE_LE(
        input->dims().size(), 5,
        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");

    PADDLE_ENFORCE_GE(
        filter->dims().size(), 4,
        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
    PADDLE_ENFORCE_LE(
        filter->dims().size(), 5,
        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");

    if (bias) {
      PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
                        "Wrong layout set for Bias tensor");
      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef,
                        "Wrong format set for Bias tensor");

      PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
                        "Bias must only have 1 dimension, i.e. X");
    }

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    float fuse_beta = ctx.Attr<float>("fuse_beta");
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    int groups = ctx.Attr<int>("groups");
    bool is_conv3d = strides.size() == 3U;

    PADDLE_ENFORCE(
        is_conv3d
            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
                  dilations[2] == 1
            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
        "dilation in convolution is not implemented yet");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();

    auto src_tz = paddle::framework::vectorize<int>(input->dims());
    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g);
    auto dst_tz = paddle::framework::vectorize<int>(output->dims());

    // Get unique name for storing MKLDNN primitives
    const std::string key = platform::CreateKey(
        src_tz, weights_tz, fuse_activation, strides, paddings, dilations,
        groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));

    std::vector<primitive> pipeline;

    auto src_format = input->format();
    MKLDNNMemoryFormat weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);

    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    weights_format = MKLDNNMemoryFormat::any;
    // Check the format for user's special output
    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
    }

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    std::vector<int> bias_tz;
    auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

    // create a conv primitive descriptor and save it for usage in backward
    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                 : mkldnn::prop_kind::forward_training;
    if (bias) {
      bias_tz = paddle::framework::vectorize<int>(bias->dims());
      auto bias_md = platform::MKLDNNMemDesc(
          bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
          fwd_prop_kind);
    } else {
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, boost::none, dst_md, strides, paddings,
          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
          fuse_residual_conn, fwd_prop_kind);
    }

    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));

    // create reorder primitive if the input format is not the preferred one
    auto src_memory_p =
        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
        user_weights_memory_p, pipeline, is_test);

    std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;

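    // Residual (conv + elementwise_add) fusion: the convolution accumulates
    // into the ResidualData tensor, so the destination memory is taken from
    // (or reordered from) the residual input rather than a fresh buffer.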
    if (fuse_residual_conn) {
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto residual_param_data = residual_param->data<T>();

      PADDLE_ENFORCE(
          residual_param_data != nullptr,
          "Provide data if you want MKLDNN conv+elementwise_add fusion");
      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                        "Output and elementwise parameter need to have the "
                        "same dimension sizes");

      if (residual_param->format() != handler.GetDstFormat()) {
        auto output_data =
            output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
        auto residual_data_tz =
            paddle::framework::vectorize<int>(residual_param->dims());
        auto residual_data_type =
            paddle::framework::ToMKLDNNDataType(residual_param->type());

        auto user_residual_md = platform::MKLDNNMemDesc(
            residual_data_tz, residual_data_type, residual_param->format());
        user_residual_memory_p = handler.AcquireResidualDataMemory(
            user_residual_md, to_void_cast<T>(residual_param_data));

        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
      } else {
        output->ShareDataWith(*residual_param);
        auto output_data = output->mutable_data<T>(ctx.GetPlace());
        dst_memory_p =
            handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
      }
    } else {
      auto output_data =
          output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
      dst_memory_p =
          handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
    }

    // create convolution op primitive
    std::shared_ptr<mkldnn::convolution_forward> conv_p;
    std::shared_ptr<mkldnn::memory> user_bias_memory_p, bias_memory_p;
    if (bias) {
      const T* bias_data = bias->data<T>();
      auto user_bias_md = platform::MKLDNNMemDesc(
          {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
      user_bias_memory_p =
          handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));

      bias_memory_p =
          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          bias_memory_p, dst_memory_p);
    } else {
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          dst_memory_p);
    }

    // push primitive to stream and wait until it's executed
    pipeline.push_back(*conv_p);
    stream(stream::kind::eager).submit(pipeline).wait();

    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  }

  template <typename T_out>
  void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
    const bool is_test = ctx.Attr<bool>("is_test");

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Filter tensor");
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Filter tensor");

    PADDLE_ENFORCE_GE(
        input->dims().size(), 4,
        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE_LE(
        input->dims().size(), 5,
        "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");

    PADDLE_ENFORCE_GE(
        filter->dims().size(), 4,
        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
    PADDLE_ENFORCE_LE(
        filter->dims().size(), 5,
        "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");

    if (bias) {
      PADDLE_ENFORCE_EQ(bias->layout(), DataLayout::kMKLDNN,
                        "Wrong layout set for Bias tensor");
      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::format_undef,
                        "Wrong format set for Bias tensor");

      PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
                        "Bias must only have 1 dimension, i.e. X");
    }

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    float fuse_beta = ctx.Attr<float>("fuse_beta");
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    bool unsigned_output =
        (fuse_activation == "relu" || fuse_activation == "relu6");
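    // Quantization scales produced by the quantization pass. With
    // force_fp32_output the result is dequantized back to fp32, so Scale_out
    // is fixed at 1.0.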
    auto scale_in_data = ctx.Attr<float>("Scale_in");
    auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
    auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
    auto scale_out_data =
        force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");

    PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output,
                   "residual fusion does not support force output with fp32");

    bool is_conv3d = strides.size() == 3U;
    PADDLE_ENFORCE(
        is_conv3d
            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
                  dilations[2] == 1
            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
        "dilation in convolution is not implemented yet");

    PADDLE_ENFORCE_NE(is_conv3d, true,
                      "int8 does not support conv3d currently");

    const T* input_data = input->data<T>();

    auto src_tz = paddle::framework::vectorize<int>(input->dims());
    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g);
    auto dst_tz = paddle::framework::vectorize<int>(output->dims());

    mkldnn::memory::data_type src_dt =
        paddle::framework::ToMKLDNNDataType(input->type());

    std::string key = platform::CreateKey(
        src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
        input->format(), fuse_activation, fuse_residual_conn,
        ctx.op().Input("Input") + ctx.op().Input("Filter"));

    std::shared_ptr<mkldnn::convolution_forward> conv_p;
    std::shared_ptr<mkldnn::memory> src_memory_p;
    std::shared_ptr<mkldnn::memory> user_src_memory_p;
    std::vector<primitive> pipeline;
    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
    std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;

    const float* filter_data = filter->data<float>();
    bool is_multi_channel = scale_weights_data.size() > 1;

    auto output_shift_scale = ComputeOutputShiftScale(
        scale_out_data, scale_in_data, scale_weights_data);

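    // Scale for the fused sum (residual) post-op: the residual input is
    // quantized with Scale_in_eltwise while the output uses Scale_out, so the
    // accumulated values are rescaled by Scale_out / Scale_in_eltwise.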
    float scale_residual =
        fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
    auto user_src_md =
        platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<float>(),
        ((g) == 1) ? mkldnn::memory::format::oihw
                   : mkldnn::memory::format::goihw);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    auto src_md = platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(weights_tz, memory::data_type::s8,
                                              chosen_memory_format);
    auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
    auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
                               : mkldnn::prop_kind::forward_training;

    std::vector<int> bias_tz;

    if (bias) {
      bias_tz = paddle::framework::vectorize<int>(bias->dims());
      auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
                                             mkldnn::memory::format::x);
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
          propagation, output_shift_scale, scale_residual);
    } else {
508 509 510 511 512
      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
          src_md, weights_md, boost::none, dst_md, strides, paddings,
          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
          fuse_residual_conn, propagation, output_shift_scale, scale_residual);
    }

    // create mkldnn memory from input tensors (data/weights)
    user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<float>(filter_data));

    // create reorder primitive if the input format is not the preferred one
    src_memory_p =
        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);

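    // Quantize the fp32 weights to s8 during the reorder, using Scale_weights
    // and a mask that selects per-tensor or per-output-channel (and per-group)
    // scaling.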
    std::shared_ptr<mkldnn::memory> weights_memory_p;

    int mask_reorder = ComputeWeightsMask(is_multi_channel, g);

    weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
        user_weights_memory_p, pipeline, is_test, true, scale_weights_data,
        mask_reorder);

    if (fuse_residual_conn) {
      auto residual_param = ctx.Input<Tensor>("ResidualData");
      auto residual_param_data = residual_param->data<T_out>();
      PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
                        "Output and elementwise parameter need to have the "
                        "same dimension sizes");
      auto residual_dt =
          paddle::framework::ToMKLDNNDataType(residual_param->type());
      if (residual_param->format() != handler.GetDstFormat()) {
        auto residual_data_tz =
            paddle::framework::vectorize<int>(residual_param->dims());
        auto user_residual_md = platform::MKLDNNMemDesc(
            residual_data_tz, residual_dt, residual_param->format());

        user_residual_memory_p = handler.AcquireResidualDataMemory(
            user_residual_md, to_void_cast<T_out>(residual_param_data));

        T_out* output_data = output->mutable_data<T_out>(ctx.GetPlace());
        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
            user_residual_memory_p, to_void_cast<T_out>(output_data), pipeline);

      } else {
        output->ShareDataWith(*residual_param);
        auto output_data = output->mutable_data<T_out>(ctx.GetPlace());
        dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
            to_void_cast<T_out>(output_data));
      }
    } else {
      T_out* output_data = output->mutable_data<T_out>(
          ctx.GetPlace(), handler.GetDstMemorySize());
      dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
          to_void_cast<T_out>(output_data));
    }

    // create convolution op primitive
    if (bias) {
      const float* bias_data = bias->data<float>();
      auto user_bias_md = platform::MKLDNNMemDesc(
          {bias_tz}, platform::MKLDNNGetDataType<float>(), memory::format::x);
      auto user_bias_memory_p = handler.AcquireBiasMemory(
          user_bias_md, to_void_cast<float>(bias_data));
      std::shared_ptr<mkldnn::memory> bias_memory_p;

      auto scale_bias_data =
          ComputeBiasScale(scale_in_data, scale_weights_data);
      int mask_bias_reorder = ComputeBiasMask(is_multi_channel);
      bias_memory_p = handler.AcquireBiasMemoryFromPrimitive(
          user_bias_memory_p, pipeline, is_test, true, scale_bias_data,
          mask_bias_reorder);
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          bias_memory_p, dst_memory_p);
    } else {
      conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
                                          dst_memory_p);
    }
    // push primitive to stream and wait until it's executed
    pipeline.push_back(*conv_p);
    stream(stream::kind::eager).submit(pipeline).wait();

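    // A fused relu clamps the result to non-negative values, so an s8 output
    // buffer is exposed as u8 to the next operator.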
    if (platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8 &&
        unsigned_output) {
      output->mutable_data<uint8_t>(ctx.GetPlace());
    }
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_memory_p));
  }
};

template <typename T>
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("Input");
    const Tensor* filter = ctx.Input<Tensor>("Filter");
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));

    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input tensor");
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Input tensor");

    PADDLE_ENFORCE_EQ(filter->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Filter tensor");
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for Filter tensor");

    PADDLE_ENFORCE_EQ(output_grad->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for output_grad tensor");
    PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::format_undef,
                      "Wrong format set for output_grad tensor");

    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
        "is_test attribute should be set to False in training phase.");

    if (!input_grad && !filter_grad) return;

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");

    bool is_conv3d = strides.size() == 3U;
    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data = nullptr;
    T* filter_grad_data = nullptr;

    auto src_tz = paddle::framework::vectorize<int>(input->dims());
    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g);
    auto dst_tz = paddle::framework::vectorize<int>(output_grad->dims());
    auto src_format = input->format();
    MKLDNNMemoryFormat weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);

661
    // Get an unique name from "argument" name of "input" and "Filter" variable
J
Jacek Czaja 已提交
662
    // as well as attributes of primitive to be created
663
    // This name will be used as key when saving info into device context
664
    const std::string key = platform::CreateKey(
665 666
        src_tz, weights_tz, "", strides, paddings, dilations, groups,
        ctx.op().Input("Input") + ctx.op().Input("Filter"));
667 668

    const std::string key_conv_pd = key + "@conv_pd";
669
    std::vector<primitive> pipeline;
670

671 672
    // Create user memory descriptors
    auto user_src_md = platform::MKLDNNMemDesc(
673
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
674
    auto user_weights_md = platform::MKLDNNMemDesc(
675
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
676 677
    auto user_diff_dst_md = platform::MKLDNNMemDesc(
        {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
678 679 680 681 682

    /* create memory descriptor for conv backward without specified format
     * ('any') which lets a primitive (conv backward in this case) choose
     * the memory format preferred for best performance
     */
    std::string data_format = ctx.Attr<std::string>("data_format");
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    weights_format = MKLDNNMemoryFormat::any;
    // Check the format for user's special output
    if (chosen_memory_format != MKLDNNMemoryFormat::any) {
      if (is_conv3d) {
        chosen_memory_format =
            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
      }
    }

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto diff_src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

    // Retrieve conv_pd from device context
    auto conv_pd =
        std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
            dev_ctx.GetBlob(key_conv_pd));
    PADDLE_ENFORCE(conv_pd != nullptr,
                   "Fail to find conv_pd in device context");

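    // The forward primitive descriptor is passed as a hint when creating the
    // backward-weights and backward-data primitive descriptors below.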
    // create backward convolution weights primitive descriptor
    auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
        mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
        strides, paddings, paddings, mkldnn::padding_kind::zero);
    auto conv_bwd_weights_pd =
        std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
            conv_bwd_weights_desc, mkldnn_engine, *conv_pd);

    // create backward convolution data primitive descriptor
    auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
        mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
        strides, paddings, paddings, mkldnn::padding_kind::zero);
    auto conv_bwd_data_pd =
        std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
            conv_bwd_data_desc, mkldnn_engine, *conv_pd);

    platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd,
                                        conv_bwd_weights_pd, dev_ctx,
                                        mkldnn_engine, key);

    // create mkldnn memory from input tensors (data/weights)
    auto user_src_memory_p =
        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
    auto user_weights_memory_p = handler.AcquireWeightsMemory(
        user_weights_md, to_void_cast<T>(filter_data));
    auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
        user_diff_dst_md, to_void_cast<T>(output_grad_data));

    // create backward conv primitive for weights
    if (filter_grad) {
      auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
          user_src_memory_p, pipeline);

      auto diff_dst_memory_4filter_p =
          handler.AcquireDiffDstMemoryFromWeightsPrimitive(
              user_diff_dst_memory_p, pipeline);

      const size_t size = handler.GetDiffWeightsMemorySize();
      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);

      auto diff_weights_memory_p =
          handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
              reinterpret_cast<void*>(filter_grad_data));

      auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights(
          src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p);

      // push primitive to stream and wait until it's executed
      pipeline.push_back(*conv_bwd_weights_p);

      filter_grad->set_layout(DataLayout::kMKLDNN);
      filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
    }

    if (input_grad) {
      auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
          user_weights_memory_p, pipeline);

      auto diff_dst_memory_4data_p =
          handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
                                                        pipeline);

      const size_t size = handler.GetDiffSourceMemorySize();
      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);

      auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
          reinterpret_cast<void*>(input_grad_data));

      auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData(
          diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p);

      pipeline.push_back(*conv_bwd_data_p);

      input_grad->set_layout(DataLayout::kMKLDNN);
      input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
    }
    stream(stream::kind::eager).submit(pipeline).wait();
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

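// conv2d has FP32, U8 (uint8 input) and S8 (int8 input) MKLDNN kernels;
// the gradient kernels and conv3d are registered for FP32 only.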
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, U8,
                                    ops::kConvMKLDNNINT8,
                                    ops::ConvMKLDNNOpKernel<uint8_t>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                    ::paddle::platform::CPUPlace, S8,
                                    ops::kConvMKLDNNINT8,
                                    ops::ConvMKLDNNOpKernel<int8_t>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);