/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"

using dnnl::memory;
using dnnl::primitive;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::framework::vectorize;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;

namespace {

// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
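// Example: a [2, 3, 4] tensor is viewed as [6, 4]; no data is moved.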
static Tensor FoldOuterDims(const Tensor& input) {
  auto output = input;
  auto in_dims = input.dims();
  if (in_dims.size() == 3) {
    output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
  }
  return output;
}

// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
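// Example: a [2, 3, 4] tensor becomes [3, 8], reordered as abc -> bac.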
template <typename T>
static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
                                   const Tensor* input) {
  auto input_dims = vectorize(input->dims());
  if (input_dims.size() != 3) {
    return *input;
  }

  Tensor output;
  output.Resize({input_dims[1], input_dims[0], input_dims[2]});

  auto output_dims = vectorize(output.dims());
  memory::data_type input_type =
      paddle::framework::ToMKLDNNDataType(input->type());
  std::string key = paddle::platform::CreateKey(
      dev_ctx, input_dims, input->format(), input->format(), input_type);
  paddle::platform::ReorderMKLDNNHandler reorder_handler(
      output_dims, input->type(), input_type, dev_ctx, dev_ctx.GetEngine(),
      key);

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      memory::format_tag::abc,
      paddle::platform::to_void_cast(input->data<T>()));
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      &output, memory::format_tag::bac, dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
                                                  reorder_dst_memory_p);

  paddle::platform::RecordEvent record_reorder(
      "int_reorder", paddle::platform::EventRole::kUniqueOp);
  auto& astream = MKLDNNDeviceContext::tls().get_stream();
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  astream.wait();

  output.Resize({input_dims[1], input_dims[0] * input_dims[2]});
  return output;
}

template <typename T>
class MatMulMKLDNNHandler
    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
 public:
  MatMulMKLDNNHandler(const mkldnn::engine engine,
                      paddle::platform::Place cpu_place, Tensor* x,
                      bool trans_x, Tensor* y, bool trans_y, Tensor* out,
                      float scale)
      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
                                                                   cpu_place) {
    auto mat_dim_x =
        paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
    auto mat_dim_y =
        paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);

    memory::dim x_bs = mat_dim_x.batch_size_;
    memory::dim y_bs = mat_dim_y.batch_size_;

    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
    const memory::dim M = mat_dim_x.height_;
    const memory::dim N = mat_dim_y.width_;
    const memory::dim K = mat_dim_x.width_;

    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
    memory::dims out_dims = {out_bs, M, N};

    // Strides express the logical row-major layout on the existing buffer: a
    // transposed operand keeps its memory and swaps the two innermost strides
    // instead of copying data.
    memory::dims x_strides =
        !trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

    memory::dims y_strides =
        !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
    memory::dims out_strides = memory::dims{M * N, N, 1};

    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
    auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);

    dnnl::primitive_attr attrs;
    // Mask 0 applies a single, tensor-wide scale to the output.
    if (scale != 1.0f) attrs.set_output_scales(0, {scale});

    this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
  }

  std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
    const T* input_data = input->data<T>();
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
                                            to_void_cast<T>(input_data));
  }
};

template <typename T>
constexpr bool IsInt8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

template <typename T>
constexpr bool IsBfloat16() {
  return std::is_same<T, paddle::platform::bfloat16>::value;
}

// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
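// Example: [5] becomes [1, 5].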
static paddle::framework::DDim RowMatrixDimsFromVector(
    const paddle::framework::DDim& x_dim) {
  return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}

// Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
// original y_dim is returned.
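// Example: [5] becomes [5, 1].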
static paddle::framework::DDim ColumnMatrixDimsFromVector(
    const paddle::framework::DDim& y_dim) {
  return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}

/**
 * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
 *
 * The shape would be [BatchSize, H, W] or [H, W].
 * If transposed, `H,W` will be swapped.
 */
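// Example: a descriptor with batch_size = 4, height = 2, width = 3 and
// trans_ = true resizes x to [4, 3, 2].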
static void ReshapeTensorToMatrixSequence(
    Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
  int64_t h, w;
  h = descriptor.height_;
  w = descriptor.width_;
  if (descriptor.trans_) {
    std::swap(w, h);
  }
  if (descriptor.batch_size_) {
    x->Resize({descriptor.batch_size_, h, w});
  } else {
    x->Resize({h, w});
  }
}

/**
 * Reshape the x, y and out tensors to 3-D or 2-D according to their matrix
 * descriptors, where Out = matmul(x, y).
 *
 * This method first computes the X and Y matrix sequences, and then derives
 * the shape of out.
 *
 * Assume X = [BatchSize, H1, W1] and Y = [BatchSize, H2, W2].
 * Then out = [BatchSize, H1, W2].
 *
 * If neither `X` nor `Y` has a batch size, out will be [H1, W2].
 * If either `X` or `Y` has batch size BatchSize, out will have that
 * BatchSize.
 */
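// Example: X = [4, 2, 3] and Y = [3, 5] with no transposes yield
// x = [4, 2, 3], y = [3, 5] and out = [4, 2, 5].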
static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
                                         bool trans_x, bool trans_y) {
  auto x_dim = RowMatrixDimsFromVector(x->dims());
  auto y_dim = ColumnMatrixDimsFromVector(y->dims());
  auto mat_dim_x =
      paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y =
      paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
  } else {
    out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
                 mat_dim_x.height_, mat_dim_y.width_});
  }

  ReshapeTensorToMatrixSequence(x, mat_dim_x);
  ReshapeTensorToMatrixSequence(y, mat_dim_y);
}

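// Builds the oneDNN matmul primitive and its memory objects once per cache
// key; subsequent executions only refresh the data pointers (see
// UpdateDataPointers) before re-running the primitive.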
template <typename XT, typename YT, typename OT>
class MatMulFactory {
 public:
  void CreateAndExecute(const ExecutionContext& ctx) {
    SetDNNLEngine(ctx);
    if (IsInitialized()) {
      UpdateDataPointers(ctx);
      Execute();
      SetOutputFormat(ctx);
      return;
    }
    CreateMemories(ctx);
    CreatePrimitive(ctx);
    Execute();
    SetOutputFormat(ctx);
    SetInitialized();
  }

 private:
  struct MatMulDims {
    const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides,
        out_strides;
  };

  void SetDNNLEngine(const ExecutionContext& ctx) {
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    engine_ = dev_ctx.GetEngine();
  }

  template <typename T>
  dnnl::memory CreateMemory(const memory::dims& dims,
                            const memory::dims& strides, const T* data) {
    auto md = memory::desc(dims, MKLDNNGetDataType<T>(), strides);
    return dnnl::memory(md, engine_, to_void_cast(data));
  }

  std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
                                 const std::vector<int>& axis) {
    size_t in_rank = x.size();
    size_t axis_size = axis.size();

    auto axis_set = std::set<int>(axis.begin(), axis.end());
    PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
                      paddle::platform::errors::InvalidArgument(
                          "In an axis array, elements must be unique."));

    PADDLE_ENFORCE_EQ(in_rank, axis_size,
                      paddle::platform::errors::InvalidArgument(
                          "The input dimension's size "
                          "should be equal to the axis's size. "
                          "But received dimension is %d, "
                          "axis's size is %d",
                          in_rank, axis_size));

    PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
                      paddle::platform::errors::InvalidArgument(
                          "Axis values must be in the range [0, dims - 1]."));

    std::vector<int64_t> new_x(x.size());
    for (size_t i = 0; i < x.size(); i++) {
      new_x[i] = x[axis[i]];
    }
    return new_x;
  }

  std::pair<paddle::operators::math::MatDescriptor, memory::dims>
  GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) {
    auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
    auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
    auto input_dims = ctx.Input<Tensor>(input_name)->dims();
    auto new_dims = input_dims;
    if (!shape.empty() && !axis.empty()) {
      new_dims = input_dims.reshape(shape).transpose(axis);
    }

    auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
                                                   : ColumnMatrixDimsFromVector;
    paddle::operators::math::MatDescriptor mat_dim =
        paddle::operators::math::CreateMatrixDescriptor(
            MatrixDimsFromVector(new_dims), 0,
            ctx.Attr<bool>("transpose_" + input_name));

    // If a fused reshape/transpose precedes this input, derive the strides of
    // the transposed view directly so no physical reorder is required.
    memory::dims strides;
    if (!shape.empty()) {
      auto shape2 = input_dims.reshape(shape);
      strides.push_back(1);
      for (auto i = shape2.size() - 1; i > 0; --i) {
        strides.insert(strides.begin(), strides.front() * shape2[i]);
      }
      strides = Transpose(strides, axis);
      if (shape.size() == 4)
        strides.erase(strides.begin());
      else if (shape.size() == 2)
        strides.insert(strides.begin(), shape[0] * shape[1]);
      mat_dim.stride_ = strides[0];
      if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
    }
    return std::make_pair(mat_dim, strides);
  }

  bool IsInputFused(const ExecutionContext& ctx) const {
    return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
             ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
  }

  bool IsOutputFused(const ExecutionContext& ctx) const {
    auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
    auto& fused_transpose_Out =
        ctx.Attr<std::vector<int>>("fused_transpose_Out");
    return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
  }

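  // When a float output has a fused transpose_Out, adjust the output strides
  // so the result is written directly in the transposed layout.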
  void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
                                          const memory::dim N, memory::dim b,
                                          memory::dims* out_strides) const {
    if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
      *out_strides = {N, b * N, 1};
    }
  }

  MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
    paddle::operators::math::MatDescriptor mat_dim_x;
    memory::dims strides_x;
    std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
    paddle::operators::math::MatDescriptor mat_dim_y;
    memory::dims strides_y;
    std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");

    auto x_bs = mat_dim_x.batch_size_;
    auto y_bs = mat_dim_y.batch_size_;
    PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
                      paddle::platform::errors::InvalidArgument(
                          "If batch sizes of X and Y are positive, "
                          "they have to be equal."));

    // Broadcast rule: use the larger batch size; fall back to 1 when both
    // operands are plain 2-D matrices.
    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
    const memory::dim M = mat_dim_x.height_;
    const memory::dim N = mat_dim_y.width_;
    const memory::dim K = mat_dim_x.width_;

    batch_size_ = 1;
    // When fused reshape/transpose folded an extra batch dimension into the
    // operands, peel it off and execute it as separate sub-matmuls.
    if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
      auto& x_dims = ctx.Input<Tensor>("X")->dims();
      auto& y_dims = ctx.Input<Tensor>("Y")->dims();
      batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
      x_bs /= batch_size_;
      y_bs /= batch_size_;
      out_bs /= batch_size_;
    }
    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
    memory::dims out_dims = {out_bs, M, N};

    // Byte offsets used to advance the data handles between sub-matmuls.
    x_offset_ = x_bs * M * K * sizeof(XT);
    y_offset_ = y_bs * K * N * sizeof(YT);
    out_offset_ = out_bs * M * N * sizeof(OT);

    // Translate transpose_X / transpose_Y into strides unless a fused
    // reshape/transpose already provided them.
    if (strides_x.empty())
      strides_x = !ctx.Attr<bool>("transpose_X") ? memory::dims{M * K, K, 1}
                                                 : memory::dims{M * K, 1, M};
    if (strides_y.empty())
      strides_y = !ctx.Attr<bool>("transpose_Y") ? memory::dims{N * K, N, 1}
                                                 : memory::dims{N * K, 1, K};
    memory::dims out_strides = memory::dims{M * N, N, 1};

    CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);

    return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
  }

  void CreateMemories(const ExecutionContext& ctx) {
    auto matmul_dims = GetMatmulDims(ctx);

    x_mem_ = CreateMemory<XT>(matmul_dims.x_dims, matmul_dims.x_strides,
                              ctx.Input<Tensor>("X")->data<XT>());
    y_mem_ = CreateMemory<YT>(matmul_dims.y_dims, matmul_dims.y_strides,
                              ctx.Input<Tensor>("Y")->data<YT>());
    out_mem_ = CreateMemory<OT>(
        matmul_dims.out_dims, matmul_dims.out_strides,
        ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace()));
  }

  float ComputeOutputScale(const ExecutionContext& ctx) {
    float scale_x = ctx.Attr<float>("Scale_x");
    float scale_y = ctx.Attr<float>("Scale_y");
    bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
    float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
    float alpha = ctx.Attr<float>("alpha");
    // Fold alpha and the quantization scales into a single output scale:
    // dequantize the inputs (1 / (scale_x * scale_y)) and requantize the
    // result (scale_out).
    return alpha * scale_out / (scale_x * scale_y);
  }

  void CreatePrimitive(const ExecutionContext& ctx) {
    dnnl::primitive_attr attr;
    float scale_out = ComputeOutputScale(ctx);
    if (scale_out != 1.0f) {
      constexpr unsigned tensor_wide_scale = 0;
      attr.set_output_scales(tensor_wide_scale, {scale_out});
    }

    auto matmul_d = dnnl::matmul::desc(x_mem_.get_desc(), y_mem_.get_desc(),
                                       out_mem_.get_desc());
    auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine_);
    matmul_prim_ = dnnl::matmul(matmul_pd);
  }

  void Execute() {
    dnnl::stream stream(engine_);

    void* x_ptr = x_mem_.get_data_handle();
    void* y_ptr = y_mem_.get_data_handle();
    void* out_ptr = out_mem_.get_data_handle();
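    // Run one sub-matmul per folded batch (batch_size_ is 1 unless an outer
    // dimension was peeled off in GetMatmulDims), sliding the data handles
    // forward by the per-batch byte offsets.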
    for (uint16_t i = 0; i < batch_size_; i++) {
      x_mem_.set_data_handle(x_ptr);
      y_mem_.set_data_handle(y_ptr);
      out_mem_.set_data_handle(out_ptr);
      matmul_prim_.execute(stream, {
                                       {MKLDNN_ARG_SRC, x_mem_},
                                       {MKLDNN_ARG_WEIGHTS, y_mem_},
                                       {MKLDNN_ARG_DST, out_mem_},
                                   });
      x_ptr = static_cast<char*>(x_ptr) + x_offset_;
      y_ptr = static_cast<char*>(y_ptr) + y_offset_;
      out_ptr = static_cast<char*>(out_ptr) + out_offset_;
    }
    stream.wait();
  }

  void SetOutputFormat(const ExecutionContext& ctx) {
    using paddle::platform::MKLDNNFormatForSize;
    auto* out = ctx.Output<Tensor>("Out");
    auto format =
        MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
    out->set_format(format);
    out->set_layout(DataLayout::kMKLDNN);
  }

  void UpdateDataPointers(const ExecutionContext& ctx) {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Output<Tensor>("Out");
    x_mem_.set_data_handle(to_void_cast(x->data<XT>()));
    y_mem_.set_data_handle(to_void_cast(y->data<YT>()));
    out_mem_.set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
  }

  // Once initialized, the cached memories and primitive are reused; only the
  // data handles need updating.
  bool IsInitialized() { return initialized_; }

  void SetInitialized() { initialized_ = true; }

 private:
  struct memory_offsets {
    size_t x_offset;
    size_t y_offset;
    size_t out_offset;
  };

  dnnl::engine engine_;
  dnnl::memory x_mem_;
  dnnl::memory y_mem_;
  dnnl::memory out_mem_;
  dnnl::matmul matmul_prim_;
  uint32_t x_offset_;
  uint32_t y_offset_;
  uint32_t out_offset_;
  uint16_t batch_size_;
  bool initialized_ = false;
};

// Return the MatMulFactory cached under a key derived from the output name
// and batch size, creating and registering it on first use.
template <typename XT, typename YT, typename OT>
static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
    const ExecutionContext& ctx) {
  const auto& out_name = ctx.OutputName("Out");
  const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
  std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name);
  key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);

  auto factory =
      std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
  if (factory == nullptr) {
    factory = std::make_shared<MatMulFactory<XT, YT, OT>>();
    dev_ctx.SetBlob(key, factory);
  }

  return factory;
}

// Choose the appropriate primitive factory implementation based on the
// inferred output type (uint8, int8, bfloat16 or float).
template <typename XT, typename YT>
static void ExecuteMatMul(const ExecutionContext& ctx) {
  constexpr bool is_int8 = IsInt8<XT>();
  constexpr bool is_bfloat16 = IsBfloat16<XT>();
  const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
  constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
  if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
    GetPrimitiveFactory<XT, YT, float>(ctx)->CreateAndExecute(ctx);
  } else if (is_bfloat16) {
    GetPrimitiveFactory<XT, YT, paddle::platform::bfloat16>(ctx)
        ->CreateAndExecute(ctx);
  } else if (fuse_relu) {
    GetPrimitiveFactory<XT, YT, uint8_t>(ctx)->CreateAndExecute(ctx);
  } else {
    GetPrimitiveFactory<XT, YT, int8_t>(ctx)->CreateAndExecute(ctx);
  }
}

template <typename T>
class DNNLMatMulKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const ExecutionContext& ctx) const override {
    if (ctx.HasAttr("head_number")) {
      PADDLE_ENFORCE_EQ(
          ctx.Attr<int>("head_number"), 1,
          paddle::platform::errors::Unimplemented(
              "DNNL matmul doesn't support multiple heads. Expected "
              "head_number=1. But received `head_number` is %d",
              ctx.Attr<int>("head_number")));
    }
    MKLDNNDeviceContext::tls().log_lib_version();
    ExecuteMatMul<T, T>(ctx);
  }
};

}  // anonymous namespace

namespace paddle {
namespace operators {

template <typename T>
void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext& ctx) const {
  if (ctx.HasAttr("head_number")) {
    PADDLE_ENFORCE_EQ(
        ctx.Attr<int>("head_number"), 1,
        platform::errors::Unimplemented(
            "DNNL matmul doesn't support multiple heads. Expected "
            "head_number=1. But received `head_number` is %d",
            ctx.Attr<int>("head_number")));
  }
  RunKernel(ctx);
}

template <typename T>
void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
    const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
    const mkldnn::engine& engine, Tensor* x, bool trans_x,
    bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
    Tensor* out) const {
  // The gradient is calculated differently when broadcasting is used: a
  // rank-3 operand paired with a rank-2 output must first be folded to 2-D.
  bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
                      out->dims().size() == 2;

  Tensor x_combined, y_combined;
  if (!need_combine) {
    x_combined = *x;
    y_combined = *y;
  } else {
    x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
                                     : FoldFirstAndLastDims<T>(dev_ctx, x);
    y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
                                     : FoldFirstAndLastDims<T>(dev_ctx, y);
  }

  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

  MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
                                 &y_combined, trans_y, out, alpha);

  const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
  const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
  const auto dst_memory_p = handler.AcquireDstMemory(out);

  auto matmul_p = handler.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> matmul_args = {
      {DNNL_ARG_SRC, *src_memory_p},
      {DNNL_ARG_WEIGHTS, *weights_memory_p},
      {DNNL_ARG_DST, *dst_memory_p}};

  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  matmul_p->execute(astream, matmul_args);
  astream.wait();

  out->set_layout(framework::DataLayout::kMKLDNN);
  out->set_format(platform::GetMKLDNNFormat(
      dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))));
}

template <typename T>
void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {
  const auto& dev_ctx =
      ctx.template device_context<platform::MKLDNNDeviceContext>();
  const auto& onednn_engine = dev_ctx.GetEngine();

  auto x = *ctx.Input<Tensor>("X");
  auto y = *ctx.Input<Tensor>("Y");
  auto dout = *ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
  auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

  bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr<bool>("transpose_X")
                                                : ctx.Attr<bool>("trans_x");
  bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr<bool>("transpose_Y")
                                                : ctx.Attr<bool>("trans_y");

  ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);

  framework::DDim dx_dims;
  if (dx) {
    dx_dims = dx->dims();
    if (dx_dims != x.dims()) {
      dx->Resize(x.dims());
    }
  }

  framework::DDim dy_dims;
  if (dy) {
    dy_dims = dy->dims();
    if (dy_dims != y.dims()) {
      dy->Resize(y.dims());
    }
  }

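  // Standard matmul gradients: e.g. with no transposes,
  //   dX = dOut * Y^T  and  dY = X^T * dOut.
  // The trans/fold flags below encode the transposes needed in each case.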
  if (transpose_x && transpose_y) {
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
                            true, false, dx);
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
                            true, false, dy);
  } else if (transpose_x) {
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
                            &dout, true, false, dx);
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
                            &dout, false, true, dy);
  } else if (transpose_y) {
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
                            &y, false, true, dx);
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
                            false, true, dy);
  } else {
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
                            &y, true, false, dx);
    this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
                            false, true, dy);
  }

  if (dx) {
    if (dx_dims != x.dims()) {
      dx->Resize(dx_dims);
      dx->set_format(x.format());
    }
  }
  if (dy) {
    if (dy_dims != y.dims()) {
      dy->Resize(dy_dims);
      dy->set_format(y.format());
    }
  }
}

template class MatMulGradMKLDNNKernel<float>;
template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
                   DNNLMatMulKernel<float>,
                   DNNLMatMulKernel<paddle::platform::bfloat16>,
                   DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<uint8_t>);

REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::MatMulGradMKLDNNKernel<float>,
                   ops::MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);