/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::DDim;
using framework::ExecutionContext;
using framework::LoDTensor;
using framework::Tensor;

using platform::MatMulV2MKLDNNHandler;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;

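// Tags used by REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE at the bottom of this file
// to distinguish the INT8 (inner_product based) kernel from the FP32/BF16
// (matmul based) kernel.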
constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;

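// Builds, caches and re-executes a oneDNN inner_product primitive for the mul
// op. XT/YT/OT are the element types of Input(X), Input(Y) and Output(Out);
// instances of this factory are cached per shape/output name in
// GetPrimitiveFactory below.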
template <typename XT, typename YT, typename OT>
class MulPrimitiveFactory {
 public:
  explicit MulPrimitiveFactory(const dnnl::engine &engine) : engine_(engine) {}

  inner_product_forward CreateMulPrimitive(const Tensor *x_input,
                                           const Tensor *y_input,
                                           Tensor *output,
                                           const ExecutionContext &ctx) {
    /* check data format and reorder if needed */
    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

    // TODO(intel-minghui) : Remove the restriction that only supports Input(Y)
    // as weights
    PADDLE_ENFORCE_EQ(
        (std::is_same<YT, float>::value),
        true,
        platform::errors::InvalidArgument(
            "Input(Y) must be fp32 data type since only fp32 data type is "
            "supported in the current design of MKLDNN INT8."));

    auto x_matrix = UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
    auto y_matrix = UpdateDataFormat<YT>(y_input, y_num_col_dims, ctx);

    auto output_dim = output->dims();
    if (output_dim.size() != 2) {
      output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
    }

    if (mul_) {
      UpdateDataPointers(ctx, output, &x_matrix);
      Execute();
      return *(mul_);
    }

    auto src_desc = CreateMemDescriptor<XT>(&x_matrix, MKLDNNMemoryFormat::nc);
    x_input_ = CreateMemory<XT>(src_desc, &x_matrix);

    if (is_int8_) {
      const auto trans_y = TransposeInputY(&y_matrix);
      auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
      y_input_ = QuantInputY(trans_y, scale_y);
    } else {
      y_input_ = TransposeInputY(&y_matrix);
    }

    auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);

    mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
    Execute();
    return *(mul_);
  }

 private:
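  // Copies src_data into a freshly allocated dst memory while applying the
  // given output scales, i.e. a quantizing reorder (used for Input(Y)).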
  memory ReorderWithScale(const memory::desc &src_desc,
                          const memory::desc &dst_desc,
                          void *src_data,
                          const std::vector<float> &scale) {
    auto mask = scale.size() > 1 ? 1 : 0;
    dnnl::primitive_attr attr;
    attr.set_output_scales(mask, scale);

    auto src_mem = memory(src_desc, engine_, src_data);
    auto dst_mem = memory(dst_desc, engine_);

    auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attr);

    auto reorder = dnnl::reorder(reorder_pd);

    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
    {
      platform::RecordEvent record_reorder(
          "int_reorder",
          platform::TracerEventType::UserDefined,
          2,
          platform::EventRole::kUniqueOp);
      reorder.execute(astream, src_mem, dst_mem);
      astream.wait();
    }

    return dst_mem;
  }

  memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
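    // Quantize the already transposed fp32 weights to int8 using the
    // "scale_y" attribute (per-output-channel when more than one scale is
    // given).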
    const auto &dims = input_y.get_desc().data.dims;
    auto ndims = input_y.get_desc().data.ndims;
    auto y_dims = std::vector<int64_t>(dims, dims + ndims);

    auto user_y_desc = CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
    auto y_desc = CreateMemDescriptor<int8_t>(y_dims, MKLDNNMemoryFormat::oi);

    return ReorderWithScale(
        user_y_desc, y_desc, input_y.get_data_handle(), scale_y);
  }

  dnnl::primitive_attr CreateMulAttr(const ExecutionContext &ctx,
                                     bool force_fp32_output) {
    dnnl::primitive_attr mul_attr;

    auto scale_y_data = ctx.Attr<std::vector<float>>("scale_y");
    auto scale_x_data = ctx.Attr<float>("scale_x");
    auto scale_out_data =
        force_fp32_output ? 1.0f : ctx.Attr<float>("scale_out");

    bool is_multi_channel = scale_y_data.size() > 1;
    int count = is_multi_channel ? scale_y_data.size() : 1;
    std::vector<float> output_shift_scale(count);
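    // The int32 accumulator implicitly carries scale_x * scale_y[i];
    // multiplying by scale_out / (scale_x * scale_y[i]) requantizes it to the
    // requested output scale (scale_out is forced to 1.0f for fp32 output).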
    for (int i = 0; i < count; i++) {
      if (scale_y_data[i] == 0.0)
        output_shift_scale[i] = scale_out_data;
      else
        output_shift_scale[i] =
            scale_out_data / (scale_x_data * scale_y_data[i]);
    }
    int mul_mask = is_multi_channel ? 1 : 0;
    mul_attr.set_output_scales(mul_mask, output_shift_scale);

    return mul_attr;
  }

  inner_product_forward CreateMulPrimitive(const memory &x_memory,
                                           const memory &y_memory,
                                           const memory::desc &dst_desc,
                                           Tensor *output,
                                           const ExecutionContext &ctx) {
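    // The flattened 2-D mul is expressed as a oneDNN inner_product: x acts as
    // the source and the transposed (and, for int8, quantized) y as weights.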
    const auto x_desc = x_memory.get_desc();
    const auto y_desc = y_memory.get_desc();
    inner_product_forward::primitive_desc mul_prim_desc;

    const auto &mul_desc = inner_product_forward::desc(
        prop_kind::forward, x_desc, y_desc, dst_desc);

    if (is_int8_) {
      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
      auto mul_attr = CreateMulAttr(ctx, force_fp32_output);
      mul_prim_desc =
          inner_product_forward::primitive_desc(mul_desc, mul_attr, engine_);
    } else {
      mul_prim_desc = inner_product_forward::primitive_desc(mul_desc, engine_);
    }

    output_ = CreateDstMemory(mul_prim_desc, ctx, output);

    return inner_product_forward(mul_prim_desc);
  }

  void Execute() {
    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
    (*mul_).execute(astream,
                    {{DNNL_ARG_SRC, *x_input_},
                     {DNNL_ARG_WEIGHTS, *y_input_},
                     {DNNL_ARG_DST, *output_}});
    astream.wait();
  }

  template <typename T>
  Tensor UpdateDataFormat(const Tensor *data,
                          int num_col_dims,
                          const ExecutionContext &ctx) {
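    // Reorder 4-D/5-D inputs that are not already in plain nchw/ncdhw layout,
    // then flatten the result to a 2-D matrix along num_col_dims.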
    Tensor x_tmp;
    Tensor data_matrix;
    MKLDNNMemoryFormat src_fmt = data->format();
    MKLDNNMemoryFormat dst_fmt;
    auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt);

    if ((data->dims().size() == 4 &&
         src_fmt != (dst_fmt = MKLDNNMemoryFormat::nchw)) ||
        (data->dims().size() == 5 &&
         src_fmt != (dst_fmt = MKLDNNMemoryFormat::ncdhw))) {
      auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt);
      x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size());

      Reorder(src_mdesc,
              dst_mdesc,
              to_void_cast<T>(data->data<T>()),
              to_void_cast<T>(x_tmp.data<T>()));

      x_tmp.Resize(data->dims());
      x_tmp.set_format(platform::GetMKLDNNFormat(dst_mdesc));
      data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
    } else {
      data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
    }

    return data_matrix;
  }

  void UpdateDataPointers(const ExecutionContext &ctx,
                          Tensor *out,
                          const Tensor *in) {
    x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
    output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));

    if (out->format() == MKLDNNMemoryFormat::undef) {
      auto output_format = platform::GetMKLDNNFormat(*output_);
      out->set_format((MKLDNNMemoryFormat)output_format);
    }
  }

  template <typename T>
  memory::desc CreateMemDescriptor(
      const Tensor *tensor,
      MKLDNNMemoryFormat format,
      memory::data_type type = platform::MKLDNNGetDataType<T>()) {
    auto dims = phi::vectorize<int64_t>(tensor->dims());
    return platform::MKLDNNMemDesc(dims, type, format);
  }

  template <typename T>
  memory::desc CreateMemDescriptor(
      const std::vector<int64_t> &dims,
      MKLDNNMemoryFormat format,
      memory::data_type type = platform::MKLDNNGetDataType<T>()) {
    return platform::MKLDNNMemDesc(dims, type, format);
  }

  template <typename T>
  memory CreateMemory(const memory::desc &desc, const Tensor *tensor) {
    return memory(desc, engine_, to_void_cast<T>(tensor->data<T>()));
  }

  memory CreateDstMemory(
      const inner_product_forward::primitive_desc &mul_prim_desc,
      const ExecutionContext &ctx,
      Tensor *output) {
    auto dst_desc = mul_prim_desc.dst_desc();
    auto buffer_size = dst_desc.get_size();

    OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
    output->set_format(paddle::platform::GetMKLDNNFormat(dst_desc));
    return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
  }

  memory Reorder(const memory::desc &src_desc,
                 const memory::desc &dst_desc,
                 void *src_data,
                 void *dst_data = nullptr) {
    auto src_mem = memory(src_desc, engine_, src_data);
    auto dst_mem = dst_data ? memory(dst_desc, engine_, dst_data)
                            : memory(dst_desc, engine_);

    auto reorder = dnnl::reorder(src_mem, dst_mem);

    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
    {
      platform::RecordEvent record_reorder(
          "int_reorder",
          platform::TracerEventType::UserDefined,
          2,
          platform::EventRole::kUniqueOp);
      reorder.execute(astream, src_mem, dst_mem);
      astream.wait();
    }

    return dst_mem;
  }

  memory TransposeInputY(const Tensor *input_y) {
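    // oneDNN inner_product expects weights in oi (out x in) order; Y is stored
    // as in x out, so describe it with the io tag and reorder (transpose) it
    // into oi.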
    auto dims = phi::vectorize<int64_t>(input_y->dims());
    std::swap(dims[0], dims[1]);  // Correct output dimensions
    auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io);
    auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi);
    return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>()));
  }

  const dnnl::engine &engine_;
  paddle::optional<memory> x_input_;
  paddle::optional<memory> y_input_;
  paddle::optional<memory> output_;
  paddle::optional<inner_product_forward> mul_;
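  // The quantized path is selected purely by the data type of Input(X).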
  static constexpr bool is_int8_ =
      std::is_same<XT, int8_t>::value || std::is_same<XT, uint8_t>::value;
};

/* OT: output data type */
template <typename XT, typename YT, typename OT>
std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
    const MKLDNNDeviceContext &dev_ctx,
    const ExecutionContext &ctx,
    const Tensor *input_x,
    const Tensor *input_y,
    const dnnl::engine &mkldnn_engine) {
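  // The factory (and thus the compiled primitive) is cached in the device
  // context, keyed on the input dtypes/shapes and the output variable name.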
  std::string key =
      platform::CreateKey(dev_ctx,
                          framework::TransToProtoVarType(input_x->dtype()),
                          phi::vectorize(input_x->dims()),
                          framework::TransToProtoVarType(input_y->dtype()),
                          phi::vectorize(input_y->dims()),
                          ctx.OutputName("Out"));
  key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);

  auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
      dev_ctx.GetBlob(key));

  if (prim_creator == nullptr) {
    prim_creator =
        std::make_shared<MulPrimitiveFactory<XT, YT, OT>>(mkldnn_engine);
    dev_ctx.SetBlob(key, prim_creator);
  }

  return prim_creator;
}

template <typename XT, typename YT>
inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
                                      const ExecutionContext &ctx,
                                      const Tensor *input_x,
                                      const Tensor *input_y,
                                      Tensor *output,
                                      const dnnl::engine &mkldnn_engine) {
  constexpr bool is_int8 =
      std::is_same<XT, int8_t>::value || std::is_same<XT, uint8_t>::value;
  bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");

  if (is_int8 && !force_fp32_output) {
    return GetPrimitiveFactory<XT, YT, int8_t>(
               dev_ctx, ctx, input_x, input_y, mkldnn_engine)
        ->CreateMulPrimitive(input_x, input_y, output, ctx);

  } else {
    return GetPrimitiveFactory<XT, YT, float>(
               dev_ctx, ctx, input_x, input_y, mkldnn_engine)
        ->CreateMulPrimitive(input_x, input_y, output, ctx);
  }
}

/* XT: input x data type, YT: input y data type */
template <typename XT, typename YT>
class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
 public:
  void Compute(const ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Mul must use CPUPlace"));
    platform::MKLDNNDeviceContext::tls().log_lib_version();
    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto &mkldnn_engine = dev_ctx.GetEngine();

    const Tensor *x = ctx.Input<Tensor>("X");
    const Tensor *y = ctx.Input<Tensor>("Y");
    Tensor *out = ctx.Output<Tensor>("Out");
    auto out_dims = out->dims();

    auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine);

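    // CreateMulPrimitive resizes a non-2-D output to a 2-D matrix; restore the
    // original dimensions before returning.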
    if (out_dims.size() != 2) {
      out->Resize(out_dims);
    }
    out->set_layout(DataLayout::kMKLDNN);
    out->set_format(platform::MKLDNNFormatForSize(out_dims.size(),
                                                  MKLDNNMemoryFormat::nchw));
  }
};

template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
 public:
  void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }

 protected:
  void ExecuteMatMul(const ExecutionContext &ctx,
                     const MKLDNNDeviceContext &dev_ctx,
                     const dnnl::engine &onednn_engine,
                     const platform::Place &cpu_place,
                     const Tensor *x,
                     const std::vector<int64_t> &x_dims,
                     bool trans_x,
                     const Tensor *y,
                     const std::vector<int64_t> &y_dims,
                     bool trans_y,
                     Tensor *out) const {
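    // Shared fp32/bf16 path: the flattened mul (and its gradients) is lowered
    // to a oneDNN matmul built by MatMulV2MKLDNNHandler and executed on the
    // thread-local stream.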
    static const std::vector<int64_t> vec_placeholder;
    MatMulV2MKLDNNHandler<XT> handler(onednn_engine,
                                      ctx.GetPlace(),
                                      x_dims,
                                      trans_x,
                                      y_dims,
                                      trans_y,
                                      false,
                                      vec_placeholder,
                                      vec_placeholder);

    const auto src_memory_p = handler.AcquireSrcMemory(x);
    const auto weights_memory_p = handler.AcquireWeightsMemory(y);
    const auto dst_memory_p = handler.AcquireDstMemory(out);

    auto matmul_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> matmul_args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

    auto &astream = MKLDNNDeviceContext::tls().get_stream();
    matmul_p->execute(astream, matmul_args);
    astream.wait();

    out->set_layout(framework::DataLayout::kMKLDNN);
    // plain output formats are enforced inside handler
    out->set_format(platform::MKLDNNFormatForSize(
        out->dims().size(), dnnl::memory::format_tag::nchw));
  }

 private:
  void RunKernel(const ExecutionContext &ctx) const {
    const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto &onednn_engine = dev_ctx.GetEngine();

    const auto *x = ctx.Input<Tensor>("X");
    const auto *y = ctx.Input<Tensor>("Y");
    auto *out = ctx.Output<Tensor>("Out");

    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

    const Tensor x_matrix = x->dims().size() > 2
                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
                                : *x;
    const Tensor y_matrix = y->dims().size() > 2
                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
                                : *y;

    // adding mb dim because MatMulV2 handler needs it
    std::vector<int64_t> y_dims(3, 1);
    std::vector<int64_t> x_dims(3, 1);

    y_dims[1] = y_matrix.dims()[0];
    y_dims[2] = y_matrix.dims()[1];

    x_dims[1] = x_matrix.dims()[0];
    x_dims[2] = x_matrix.dims()[1];

    ExecuteMatMul(ctx,
                  dev_ctx,
                  onednn_engine,
                  ctx.GetPlace(),
                  &x_matrix,
                  x_dims,
                  false,
                  &y_matrix,
                  y_dims,
                  false,
                  out);
  }
};

template <typename XT, typename YT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
 public:
  void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }

 private:
  template <typename OT = XT>
  void RunKernel(const ExecutionContext &ctx) const {
    const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto &onednn_engine = dev_ctx.GetEngine();

    const auto *x = ctx.Input<LoDTensor>("X");
    const auto *y = ctx.Input<LoDTensor>("Y");
    const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto *dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<LoDTensor>(framework::GradVarName("Y"));

    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

    const Tensor x_matrix = x->dims().size() > 2
                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
                                : static_cast<const Tensor &>(*x);
    const Tensor y_matrix = y->dims().size() > 2
                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
                                : static_cast<const Tensor &>(*y);

    Tensor dout_matrix = *dout;
    dout_matrix.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                        phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});

    // adding mb dim because MatMulV2 handler needs it
    std::vector<int64_t> x_dims(3, 1);
    std::vector<int64_t> y_dims(3, 1);
    std::vector<int64_t> dout_dims(3, 1);

    x_dims[1] = x_matrix.dims()[0];
    x_dims[2] = x_matrix.dims()[1];

    y_dims[1] = y_matrix.dims()[0];
    y_dims[2] = y_matrix.dims()[1];

    dout_dims[1] = dout_matrix.dims()[0];
    dout_dims[2] = dout_matrix.dims()[1];

    if (dx != nullptr) {
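      // dX = dOut * Y^T, computed as a matmul with trans_y = true.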
      dx->set_lod(x->lod());
      this->ExecuteMatMul(ctx,
                          dev_ctx,
                          onednn_engine,
                          ctx.GetPlace(),
                          &dout_matrix,
                          dout_dims,
                          false,
                          &y_matrix,
                          y_dims,
                          true,
                          static_cast<Tensor *>(dx));
    }
    if (dy != nullptr) {
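      // dY = X^T * dOut, computed as a matmul with trans_x = true.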
      dy->set_lod(y->lod());
      this->ExecuteMatMul(ctx,
                          dev_ctx,
                          onednn_engine,
                          ctx.GetPlace(),
                          &x_matrix,
                          x_dims,
                          true,
                          &dout_matrix,
                          dout_dims,
                          false,
                          static_cast<Tensor *>(dy));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
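
// U8/S8 inputs are served by MulMKLDNNINT8Kernel (tagged kMULMKLDNNINT8),
// FP32/BF16 by the matmul-based MulMKLDNNKernel (tagged kMULMKLDNNFP32);
// mul_grad is registered only for the FP32/BF16 path.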
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    U8,
                                    ops::kMULMKLDNNINT8,
                                    ops::MulMKLDNNINT8Kernel<uint8_t, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    S8,
                                    ops::kMULMKLDNNINT8,
                                    ops::MulMKLDNNINT8Kernel<int8_t, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
                                    ops::kMULMKLDNNFP32,
                                    ops::MulMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    mul,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    BF16,
    ops::kMULMKLDNNFP32,
    ops::MulMKLDNNKernel<paddle::platform::bfloat16,
                         paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(mul,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::MulMKLDNNINT8Kernel<uint8_t, float>,
                   ops::MulMKLDNNKernel<paddle::platform::bfloat16,
                                        paddle::platform::bfloat16>,
                   ops::MulMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
                                    ops::kMULMKLDNNFP32,
                                    ops::MulGradMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    mul_grad,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    BF16,
    ops::kMULMKLDNNFP32,
    ops::MulGradMKLDNNKernel<paddle::platform::bfloat16,
                             paddle::platform::bfloat16>,
    ops::MulGradMKLDNNKernel<float, float>);