matmul_v2_op.h 4.2 KB
Newer Older
1
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
S
ShenLiang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
19
#include <utility>
S
ShenLiang 已提交
20
#include <vector>
21

S
ShenLiang 已提交
22 23 24
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
25 26
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
S
ShenLiang 已提交
27

28
// Only headers from the paddle/phi/api dirs can be included here.
29
#include "paddle/fluid/framework/phi_utils.h"
30 31
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
Z
zyfncg 已提交
32

33
#if defined(__NVCC__) || defined(__HIPCC__)
34
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
S
ShenLiang 已提交
35 36 37 38 39
#endif

namespace paddle {
namespace operators {

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
// Operator class for matmul_v2, derived from framework::OperatorWithKernel.
// Only declarations live in this header; the overrides below are defined
// elsewhere.
class MatMulV2Op : public framework::OperatorWithKernel {
 public:
  // Inherit the base-class constructors unchanged.
  using framework::OperatorWithKernel::OperatorWithKernel;
  // Shape-inference override; receives the inference context `ctx`.
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  // Returns the kernel key to use for the given execution context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  // Returns the kernel key for the variable named `var_name`, given its
  // tensor and the operator's expected kernel key.
  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override;
};

// Proto/checker maker for matmul_v2, derived from
// framework::OpProtoAndCheckerMaker. Make() is declared here and defined
// elsewhere.
class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Marked final, so classes deriving from this maker cannot override it.
  void Make() final;

 protected:
  // Virtual with an empty default body.
  // NOTE(review): presumably an extension hook for derived makers; its call
  // site is not visible in this header -- confirm.
  virtual void Apply() {}
};

63 64
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
65
static phi::DenseTensor FoldInitDims(const phi::DenseTensor& input) {
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
  auto output = input;
  auto in_dims = input.dims();
  if (in_dims.size() == 3) {
    output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
  }
  return output;
}

/**
 * Interpret a shape as a row matrix. A shape whose rank is greater than 1 is
 * returned unchanged; otherwise the result is the 2-D shape [1, x_dim[0]].
 */
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
  const bool needs_promotion = !(x_dim.size() > 1);
  if (needs_promotion) {
    // Place the single extent in the column position of a one-row matrix.
    return phi::make_ddim({1, x_dim[0]});
  }
  return x_dim;
}

/**
 * Interpret a shape as a column matrix. A shape whose rank is greater than 1
 * is returned unchanged; otherwise the result is the 2-D shape [y_dim[0], 1].
 */
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
  const bool needs_promotion = !(y_dim.size() > 1);
  if (needs_promotion) {
    // Place the single extent in the row position of a one-column matrix.
    return phi::make_ddim({y_dim[0], 1});
  }
  return y_dim;
}

/**
 * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
 *
 * The shape would be [BatchSize, H, W] or [H, W].
 * If transposed, `H,W` will be swapped.
 */
static void ReshapeTensorIntoMatrixSequence(
103
    phi::DenseTensor* x, const phi::funcs::MatDescriptor& descriptor) {
104 105 106 107 108 109 110 111 112 113 114 115 116
  int64_t h, w;
  h = descriptor.height_;
  w = descriptor.width_;
  if (descriptor.trans_) {
    std::swap(w, h);
  }
  if (descriptor.batch_size_) {
    x->Resize({descriptor.batch_size_, h, w});
  } else {
    x->Resize({h, w});
  }
}

117 118 119
static void ReshapeXYOutIntoMatrixSequence(phi::DenseTensor* x,
                                           phi::DenseTensor* y,
                                           phi::DenseTensor* out,
120
                                           bool trans_x,
121 122 123
                                           bool trans_y) {
  auto x_dim = RowMatrixFromVector(x->dims());
  auto y_dim = ColumnMatrixFromVector(y->dims());
124 125
  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
126 127 128 129
  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
  } else {
    out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
130 131
                 mat_dim_x.height_,
                 mat_dim_y.width_});
132 133 134 135 136 137
  }

  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}

S
ShenLiang 已提交
138 139
}  // namespace operators
}  // namespace paddle