matmul_v2_mkldnn_op.cc 26.9 KB
Newer Older
S
Sławomir Siwek 已提交
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15 16 17 18

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
19

20
namespace {
21
using dnnl::memory;
22
using paddle::framework::ExecutionContext;
23
using paddle::platform::MatMulV2MKLDNNHandler;
24
using phi::OneDNNContext;
25
using phi::vectorize;
26
using phi::funcs::OneDNNGetDataType;
27
using Tensor = phi::DenseTensor;
28
using paddle::framework::GradVarName;
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

/**
 * Return a copy of `input` whose shape P x M x N is resized to (P * M) x N.
 * Tensors whose rank is not 3 are returned with their shape untouched.
 */
static Tensor FoldOuterDims(const Tensor &input) {
  Tensor folded = input;
  const auto &dims = input.dims();
  if (dims.size() != 3) {
    return folded;
  }
  // Merge the two leading dimensions; the innermost one is kept as is.
  const auto merged_rows = dims[0] * dims[1];
  folded.Resize({merged_rows, dims[2]});
  return folded;
}

// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename T>
45
static Tensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
46 47 48 49 50 51 52 53 54 55 56
                                   const Tensor *input) {
  auto input_dims = vectorize(input->dims());
  if (input_dims.size() != 3) {
    return *input;
  }

  Tensor output;
  output.Resize({input_dims[1], input_dims[0], input_dims[2]});

  auto output_dims = vectorize(output.dims());

57
  memory::data_type input_type = phi::funcs::ToOneDNNDataType(input->dtype());
58 59
  phi::funcs::ReorderOneDNNHandler reorder_handler(
      output_dims, input->dtype(), input_type, dev_ctx.GetEngine());
60 61

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
62
      memory::format_tag::abc, phi::funcs::to_void_cast(input->data<T>()));
63 64 65 66 67
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      &output, memory::format_tag::bac, dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
                                                  reorder_dst_memory_p);

68
  auto &astream = OneDNNContext::tls().get_stream();
69 70 71 72 73 74 75
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  astream.wait();

  output.Resize({input_dims[1], input_dims[0] * input_dims[2]});
  return output;
}

J
Jacek Czaja 已提交
76 77 78
/**
 * Return the dimensions of input `input_name` as the matmul should see them.
 *
 * When both the "fused_reshape_<name>" and "fused_transpose_<name>" attributes
 * are non-empty, the tensor dims are reshaped and then transposed accordingly;
 * otherwise the tensor's own dims are returned.
 */
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
  auto fused_shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
  auto fused_axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
  auto tensor_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();

  const bool fusion_absent = fused_shape.empty() || fused_axis.empty();
  if (fusion_absent) {
    return tensor_dims;
  }
  return tensor_dims.reshape(fused_shape).transpose(fused_axis);
}

86 87
template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
88
    : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
89 90 91 92 93 94 95 96 97
 public:
  MatMulMKLDNNHandler(const dnnl::engine engine,
                      paddle::platform::Place cpu_place,
                      Tensor *x,
                      bool trans_x,
                      Tensor *y,
                      bool trans_y,
                      Tensor *out,
                      float scale)
98 99
      : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
                                                              cpu_place) {
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x);
    auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y);

    memory::dim x_bs = mat_dim_x.batch_size_;
    memory::dim y_bs = mat_dim_y.batch_size_;

    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
    const memory::dim M = mat_dim_x.height_;
    const memory::dim N = mat_dim_y.width_;
    const memory::dim K = mat_dim_x.width_;

    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
    memory::dims out_dims = {out_bs, M, N};

    memory::dims x_strides =
        !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

    memory::dims y_strides =
        !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
    memory::dims out_strides = memory::dims{M * N, N, 1};

122 123 124
    auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
    auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
    auto out_md = memory::desc(out_dims, OneDNNGetDataType<OT>(), out_strides);
125 126 127 128 129 130 131 132 133

    dnnl::primitive_attr attrs;
    if (scale != 1.0f) attrs.set_output_scales(0, {scale});

    this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
  }

  std::shared_ptr<memory> AcquireWeightsMemory(const Tensor *input) {
    const YT *input_data = input->data<YT>();
134 135 136
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->weights_desc(),
        phi::funcs::to_void_cast<YT>(input_data));
137 138 139
  }

 public:
140 141 142
  void Execute(const phi::DenseTensor *x,
               const phi::DenseTensor *y,
               phi::DenseTensor *out) {
143 144 145 146 147 148 149 150 151 152 153
    const auto src_memory_p = this->AcquireSrcMemory(x);
    const auto weights_memory_p = this->AcquireWeightsMemory(y);
    const auto dst_memory_p = this->AcquireDstMemory(out);

    auto matmul_p = this->AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> matmul_args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

154
    auto &astream = OneDNNContext::tls().get_stream();
155 156 157 158 159

    // Simulate batch matmul by processing in loop
    void *x_ptr = src_memory_p->get_data_handle();
    void *y_ptr = weights_memory_p->get_data_handle();
    void *out_ptr = dst_memory_p->get_data_handle();
S
Sławomir Siwek 已提交
160 161
    auto offsets = std::make_tuple(x_offset_, y_offset_, out_offset_);
    for (uint16_t i = 0; i < batch_size_; ++i) {
162 163 164
      src_memory_p->set_data_handle(x_ptr);
      weights_memory_p->set_data_handle(y_ptr);
      dst_memory_p->set_data_handle(out_ptr);
165
      matmul_p->execute(astream, matmul_args);
166 167 168 169 170 171
      x_ptr = static_cast<char *>(x_ptr) + std::get<0>(offsets);
      y_ptr = static_cast<char *>(y_ptr) + std::get<1>(offsets);
      out_ptr = static_cast<char *>(out_ptr) + std::get<2>(offsets);
    }
    astream.wait();

172
    out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
173 174
  }

175
  std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor *output) {
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
    // We cannot use base AcquireDstMemory as it makes an allocation request
    // base on DST memory primitive size. This is fine in general, but in MatMul
    // we have primitive that covers only one batch of Data and then shift
    // pointer for every new batch. Hence Tensor size is bigger that dst memory
    // primitive size. So would we request less memory that is there and it
    // triggers an
    // assertion.  So as there is no 'any' format here we can leave default size
    // of Tensor as computed in ComputeInferShape
    OT *ptr = output->mutable_data<OT>(this->place_);
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
  }

 private:
  uint32_t x_offset_;
  uint32_t y_offset_;
  uint32_t out_offset_;
  uint16_t batch_size_;
};

/**
 * Resize `x` to the 2-D or 3-D shape described by `descriptor`.
 *
 * The resulting shape is [batch_size, H, W] when the descriptor has a
 * non-zero batch size and [H, W] otherwise. For a transposed descriptor the
 * height and width are exchanged first.
 */
static void ReshapeTensorToMatrixSequence(
    Tensor *x, const phi::funcs::MatDescriptor &descriptor) {
  const bool transposed = descriptor.trans_;
  const int64_t rows = transposed ? descriptor.width_ : descriptor.height_;
  const int64_t cols = transposed ? descriptor.height_ : descriptor.width_;

  if (descriptor.batch_size_ == 0) {
    x->Resize({rows, cols});
    return;
  }
  x->Resize({descriptor.batch_size_, rows, cols});
}

/**
 * Resize x, y and out (Out = matmul(x, y)) to 2-D or 3-D matrix sequences.
 *
 * Matrix descriptors are built for X and Y first and the shape of `out` is
 * derived from them:
 *   - neither X nor Y has a batch size  -> out is [H1, W2];
 *   - otherwise                         -> out is [BatchSize, H1, W2], where
 *     BatchSize is the larger of the two batch sizes.
 * Afterwards x and y themselves are resized according to their descriptors.
 */
static void ReshapeXYOutToMatrixSequence(
    Tensor *x, Tensor *y, Tensor *out, bool trans_x, bool trans_y) {
  const auto x_matrix_dims = phi::funcs::RowMatrixDimsFromVector(x->dims());
  const auto y_matrix_dims = phi::funcs::ColumnMatrixDimsFromVector(y->dims());
  const auto x_desc =
      phi::funcs::CreateMatrixDescriptor(x_matrix_dims, 0, trans_x);
  const auto y_desc =
      phi::funcs::CreateMatrixDescriptor(y_matrix_dims, 0, trans_y);

  const bool batched = x_desc.batch_size_ != 0 || y_desc.batch_size_ != 0;
  if (batched) {
    out->Resize({std::max(x_desc.batch_size_, y_desc.batch_size_),
                 x_desc.height_,
                 y_desc.width_});
  } else {
    out->Resize({x_desc.height_, y_desc.width_});
  }

  ReshapeTensorToMatrixSequence(x, x_desc);
  ReshapeTensorToMatrixSequence(y, y_desc);
}

S
Sławomir Siwek 已提交
248 249
std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
                               const std::vector<int> &axis) {
250 251 252 253
  size_t in_rank = x.size();
  size_t axis_size = axis.size();

  auto axis_set = std::set<int>(axis.begin(), axis.end());
254 255
  PADDLE_ENFORCE_EQ(axis_set.size(),
                    axis_size,
256 257 258
                    paddle::platform::errors::InvalidArgument(
                        "In an axis array, elements must be unique."));

259 260
  PADDLE_ENFORCE_EQ(in_rank,
                    axis_size,
261 262 263 264 265
                    paddle::platform::errors::InvalidArgument(
                        "The input dimension's size "
                        "should be equal to the axis's size. "
                        "But received dimension is %d, "
                        "axis's size is %d",
266 267
                        in_rank,
                        axis_size));
268

269 270
  PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
                    axis_size,
271 272 273 274 275 276 277 278 279 280
                    paddle::platform::errors::InvalidArgument(
                        "Axis values must be ranging from 0 to (dims - 1)."));

  std::vector<int64_t> new_x(x.size());
  for (size_t i = 0; i < x.size(); i++) {
    new_x[i] = x[axis[i]];
  }
  return new_x;
}

281
std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
282 283 284
                                     const std::string input_name) {
  auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
  auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
285
  auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
286 287 288 289 290
  auto new_dims = input_dims;
  if (!shape.empty() && !axis.empty()) {
    new_dims = input_dims.reshape(shape).transpose(axis);
  }

291 292 293
  auto &MatrixDimsFromVector = input_name == "X"
                                   ? phi::funcs::RowMatrixDimsFromVector
                                   : phi::funcs::ColumnMatrixDimsFromVector;
294
  phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
295 296
      MatrixDimsFromVector(new_dims),
      0,
297 298 299 300
      ctx.HasAttr("trans_x")
          ? ctx.Attr<bool>(std::string("trans_") +
                           static_cast<char>(std::tolower(input_name[0])))
          : ctx.Attr<bool>(std::string("transpose_") + input_name[0]));
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319

  std::vector<int64_t> strides;
  if (!shape.empty()) {
    auto shape2 = input_dims.reshape(shape);
    strides.push_back(1);
    for (auto i = shape2.size() - 1; i > 0; --i) {
      strides.insert(strides.begin(),
                     strides.front() * static_cast<int64_t>(shape2[i]));
    }
    strides = Transpose(strides, axis);
    if (shape.size() == 2)
      strides.insert(strides.begin(),
                     static_cast<int64_t>(shape[0] * shape[1]));
    mat_dim.stride_ = strides[0];
    if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
  }
  return strides;
}

320 321 322
/**
 * Tell whether a reshape + transpose of the output has been fused into the
 * op, i.e. whether both "fused_reshape_Out" and "fused_transpose_Out" are
 * non-empty.
 */
bool IsOutputFused(const ExecutionContext &ctx) {
  const auto &reshape_attr = ctx.Attr<std::vector<int>>("fused_reshape_Out");
  const auto &transpose_attr =
      ctx.Attr<std::vector<int>>("fused_transpose_Out");
  const bool any_empty = reshape_attr.empty() || transpose_attr.empty();
  return !any_empty;
}

326
template <typename T, typename T_out>
327
void ExecuteMatMulV2(const ExecutionContext &ctx,
328
                     const dnnl::engine onednn_engine,
329 330
                     const Tensor *x,
                     const std::vector<int64_t> &x_dims,
331
                     bool trans_x,
332 333
                     const Tensor *y,
                     const std::vector<int64_t> &y_dims,
334
                     bool trans_y,
335
                     Tensor *out) {
336 337
  std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
  std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
338 339 340 341 342 343 344 345 346 347
  MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
                                             onednn_engine,
                                             ctx.GetPlace(),
                                             x_dims,
                                             trans_x,
                                             y_dims,
                                             trans_y,
                                             IsOutputFused(ctx),
                                             x_strides_override,
                                             y_strides_override);
348

349 350 351
  const auto src_memory_p = handler.AcquireSrcMemory(x);
  const auto weights_memory_p = handler.AcquireWeightsMemory(y);
  const auto dst_memory_p = handler.AcquireDstMemory(out);
352

353
  auto matmul_p = handler.AcquireForwardPrimitive();
354

355 356 357 358
  std::unordered_map<int, memory> matmul_args = {
      {DNNL_ARG_SRC, *src_memory_p},
      {DNNL_ARG_WEIGHTS, *weights_memory_p},
      {DNNL_ARG_DST, *dst_memory_p}};
359

360
  if (ctx.HasInput("ResidualData")) {
361
    auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
362 363 364 365 366
    const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
    matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
                        *residual_data_memory_p});
  }

367
  auto &astream = OneDNNContext::tls().get_stream();
368 369
  matmul_p->execute(astream, matmul_args);
  astream.wait();
370 371 372

  // TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
  // permute
373
  if (IsOutputFused(ctx) && !phi::funcs::is_int8<T_out>()) {
374 375
    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
    auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
376
    out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
377 378
  } else {
    out->set_mem_desc(
379
        dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
380
  }
381 382 383
}

template <typename T>
384
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
385
 public:
386 387 388 389 390 391 392 393 394 395
  void Compute(const ExecutionContext &ctx) const override {
    if (ctx.HasAttr("head_number")) {
      PADDLE_ENFORCE_EQ(
          ctx.Attr<int>("head_number"),
          1,
          paddle::platform::errors::Unimplemented(
              "oneDNN matmul doesn't support multiple heads. Expected "
              "head_number=1. But received `head_number` is %d",
              ctx.Attr<int>("head_number")));
    }
396 397
    constexpr bool is_int8 = phi::funcs::is_int8<T>();
    constexpr bool is_bfloat16 = phi::funcs::is_bfloat16<T>();
398 399 400 401
    const bool force_fp32_output = ctx.HasAttr("force_fp32_output")
                                       ? ctx.Attr<bool>("force_fp32_output")
                                       : false;
    constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
402

403
    const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
    const auto &onednn_engine = dev_ctx.GetEngine();

    auto *x = ctx.Input<phi::DenseTensor>("X");
    auto *y = ctx.Input<phi::DenseTensor>("Y");
    auto *out = ctx.Output<phi::DenseTensor>("Out");
    bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x")
                                          : ctx.Attr<bool>("transpose_X");
    bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
                                          : ctx.Attr<bool>("transpose_Y");

    auto x_dims = vectorize(GetDimForInput(ctx, "X"));
    auto y_dims = vectorize(GetDimForInput(ctx, "Y"));

    int ndims = std::max(x_dims.size(), y_dims.size());
    ndims = std::max(ndims, 3);

    std::vector<int64_t> x_bd_dims(ndims, 1);
    std::vector<int64_t> y_bd_dims(ndims, 1);

    CalculateMatrixDims(ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);

425
    if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
426 427 428 429 430 431 432 433 434
      ExecuteMatMulV2<T, float>(ctx,
                                onednn_engine,
                                x,
                                x_bd_dims,
                                trans_x,
                                y,
                                y_bd_dims,
                                trans_y,
                                out);
435
    } else if (is_bfloat16) {
436 437 438 439 440 441 442 443 444
      ExecuteMatMulV2<T, paddle::platform::bfloat16>(ctx,
                                                     onednn_engine,
                                                     x,
                                                     x_bd_dims,
                                                     trans_x,
                                                     y,
                                                     y_bd_dims,
                                                     trans_y,
                                                     out);
445
    } else if (fuse_relu) {
446 447 448 449 450 451 452 453 454
      ExecuteMatMulV2<T, uint8_t>(ctx,
                                  onednn_engine,
                                  x,
                                  x_bd_dims,
                                  trans_x,
                                  y,
                                  y_bd_dims,
                                  trans_y,
                                  out);
455
    } else {
456 457 458 459 460 461 462 463 464
      ExecuteMatMulV2<T, int8_t>(ctx,
                                 onednn_engine,
                                 x,
                                 x_bd_dims,
                                 trans_x,
                                 y,
                                 y_bd_dims,
                                 trans_y,
                                 out);
465 466
    }
  }
467

468
 private:
469 470 471 472 473 474
  void CalculateMatrixDims(const ExecutionContext &ctx,
                           const std::vector<int64_t> &x_dims,
                           const std::vector<int64_t> &y_dims,
                           std::vector<int64_t> *x_bd_dims,
                           std::vector<int64_t> *y_bd_dims,
                           Tensor *out) const {
475
    if (x_dims.size() == 1) {
476
      (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
477
    } else if (x_dims.size() == 2) {
478 479
      (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
      (*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
480 481
    } else {
      for (size_t i = 0; i < x_dims.size(); ++i) {
482
        (*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
483 484 485
      }
    }
    if (y_dims.size() == 1) {
486
      (*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
487
    } else if (y_dims.size() == 2) {
488 489
      (*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
      (*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
490 491
    } else {
      for (size_t i = 0; i < y_dims.size(); ++i) {
492
        (*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
493 494 495
      }
    }

496
    if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) {
497
      auto out_dims = vectorize(out->dims());
498
      for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
499
        PADDLE_ENFORCE_EQ(
500 501
            (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
                (*y_bd_dims)[i] == 1,
502 503 504 505 506
            true,
            paddle::platform::errors::InvalidArgument(
                "Tensor dimensions are incorrect for broadcasting."
                "Dimensions in X and Y must be same or equal to 1, but "
                "received x_dim[%d]=%d and y_dims[%d]= %d",
507 508 509 510
                i,
                (*x_bd_dims)[i],
                i,
                (*y_bd_dims)[i]));
511
        (out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
512
      }
513
      out->Resize(phi::make_ddim((out_dims)));
514 515
    }
  }
516
};
517

518 519 520 521 522 523 524 525 526 527 528 529 530
/**
 * oneDNN backward kernel for the matmul op: computes the gradients of X and
 * Y from X, Y and the gradient of Out.
 *
 * @tparam T element type of all tensors involved.
 */
template <typename T>
class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
 public:
  /**
   * Reshape X, Y and dOut to matrix sequences, then run one matmul per
   * requested gradient. Which operands are multiplied, and with which
   * transposition, depends on the forward transpose flags.
   *
   * dx/dy are temporarily resized to the reshaped X/Y dims for the
   * computation and restored to their original dims afterwards.
   *
   * Raises Unimplemented when a "head_number" attribute other than 1 is set.
   */
  void Compute(const ExecutionContext &ctx) const override {
    if (ctx.HasAttr("head_number")) {
      PADDLE_ENFORCE_EQ(
          ctx.Attr<int>("head_number"),
          1,
          paddle::platform::errors::Unimplemented(
              "oneDNN matmul doesn't support multiple heads. Expected "
              "head_number=1. But received `head_number` is %d",
              ctx.Attr<int>("head_number")));
    }

    const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
    const auto &onednn_engine = dev_ctx.GetEngine();

    // Work on local tensor objects so that resizing them below does not
    // change the dims of the op's input tensors.
    auto x = *ctx.Input<phi::DenseTensor>("X");
    auto y = *ctx.Input<phi::DenseTensor>("Y");
    auto dout =
        *ctx.Input<phi::DenseTensor>(paddle::framework::GradVarName("Out"));
    auto *dx =
        ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("X"));
    auto *dy =
        ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("Y"));

    bool transpose_x = ctx.HasAttr("transpose_X")
                           ? ctx.Attr<bool>("transpose_X")
                           : ctx.Attr<bool>("trans_x");
    bool transpose_y = ctx.HasAttr("transpose_Y")
                           ? ctx.Attr<bool>("transpose_Y")
                           : ctx.Attr<bool>("trans_y");

    ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);

    // Remember the original gradient dims so they can be restored later.
    paddle::framework::DDim dx_dims;
    if (dx) {
      dx_dims = dx->dims();
      if (dx_dims != x.dims()) {
        dx->Resize(x.dims());
      }
    }

    paddle::framework::DDim dy_dims;
    if (dy) {
      dy_dims = dy->dims();
      if (dy_dims != y.dims()) {
        dy->Resize(y.dims());
      }
    }

    if (transpose_x && transpose_y) {
      this->ExecuteMatMulGrad(
          ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx);
      this->ExecuteMatMulGrad(
          ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy);
    } else if (transpose_x) {
      this->ExecuteMatMulGrad(ctx,
                              dev_ctx,
                              onednn_engine,
                              &y,
                              false,
                              false,
                              &dout,
                              true,
                              false,
                              dx);
      this->ExecuteMatMulGrad(ctx,
                              dev_ctx,
                              onednn_engine,
                              &x,
                              false,
                              false,
                              &dout,
                              false,
                              true,
                              dy);
    } else if (transpose_y) {
      this->ExecuteMatMulGrad(ctx,
                              dev_ctx,
                              onednn_engine,
                              &dout,
                              false,
                              false,
                              &y,
                              false,
                              true,
                              dx);
      this->ExecuteMatMulGrad(
          ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, true, dy);
    } else {
      this->ExecuteMatMulGrad(ctx,
                              dev_ctx,
                              onednn_engine,
                              &dout,
                              false,
                              false,
                              &y,
                              true,
                              false,
                              dx);
      this->ExecuteMatMulGrad(
          ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy);
    }

    // Restore the gradients' original dims where they were changed above.
    if (dx && dx_dims != x.dims()) {
      dx->Resize(dx_dims);
      dx->set_mem_desc(x.mem_desc());
    }
    if (dy && dy_dims != y.dims()) {
      dy->Resize(dy_dims);
      dy->set_mem_desc(y.mem_desc());
    }
  }

 private:
  /**
   * Compute out = alpha * op(x) x op(y) with a oneDNN matmul, where alpha is
   * the "alpha" attribute (1.0 when absent).
   *
   * @param x, y                 operands of the product.
   * @param trans_x, trans_y     whether the operand is used transposed.
   * @param is_fold_init_dims_*  when a 3-D operand has to be combined into a
   *                             2-D one, selects FoldOuterDims (true) or
   *                             FoldFirstAndLastDims (false).
   * @param out                  gradient tensor to fill; when null (that
   *                             gradient was not requested) nothing is done.
   */
  void ExecuteMatMulGrad(const ExecutionContext &ctx,
                         const OneDNNContext &dev_ctx,
                         const dnnl::engine &engine,
                         phi::DenseTensor *x,
                         bool trans_x,
                         bool is_fold_init_dims_x,
                         phi::DenseTensor *y,
                         bool trans_y,
                         bool is_fold_init_dims_y,
                         phi::DenseTensor *out) const {
    // Compute() treats both gradients as optional, so a null output simply
    // means there is nothing to calculate.
    if (out == nullptr) {
      return;
    }

    // gradient is calculated in a different way when broadcasting is used
    bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
                        out->dims().size() == 2;

    Tensor x_combined, y_combined;
    if (!need_combine) {
      x_combined = *x;
      y_combined = *y;
    } else {
      x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
                                       : FoldFirstAndLastDims<T>(dev_ctx, x);
      y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
                                       : FoldFirstAndLastDims<T>(dev_ctx, y);
    }

    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

    MatMulMKLDNNHandler<T, T, T> handler(engine,
                                         ctx.GetPlace(),
                                         &x_combined,
                                         trans_x,
                                         &y_combined,
                                         trans_y,
                                         out,
                                         alpha);

    const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
    const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
    const auto dst_memory_p = handler.AcquireDstMemory(out);

    auto matmul_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> matmul_args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

    auto &astream = OneDNNContext::tls().get_stream();
    matmul_p->execute(astream, matmul_args);
    astream.wait();

    out->set_mem_desc(
        dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
  }
};
693

694
}  // anonymous namespace
695

696 697 698
// Forward matmul on CPU through oneDNN: float, bfloat16 and signed/unsigned
// 8-bit integer element types.
REGISTER_OP_KERNEL(
    matmul,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    MatMulMKLDNNKernel<float>,
    MatMulMKLDNNKernel<paddle::platform::bfloat16>,
    MatMulMKLDNNKernel<int8_t>,
    MatMulMKLDNNKernel<uint8_t>);

// Backward matmul on CPU through oneDNN: float and bfloat16 only.
REGISTER_OP_KERNEL(
    matmul_grad,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    MatMulGradMKLDNNKernel<float>,
    MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);