matmul_v2_op.h 3.6 KB
Newer Older
S
ShenLiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
19
#include <utility>
S
ShenLiang 已提交
20
#include <vector>
21

S
ShenLiang 已提交
22 23 24
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
25 26
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
S
ShenLiang 已提交
27

28 29 30 31
// only can include the headers in paddle/phi/api dirs
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
Z
zyfncg 已提交
32

33
#if defined(__NVCC__) || defined(__HIPCC__)
34
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
S
ShenLiang 已提交
35 36 37 38 39
#endif

namespace paddle {
namespace operators {

40 41
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
42
static phi::DenseTensor FoldInitDims(const phi::DenseTensor& input) {
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
  auto output = input;
  auto in_dims = input.dims();
  if (in_dims.size() == 3) {
    output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
  }
  return output;
}

/**
 * Promote a vector shape to a row-matrix shape.
 *
 * A shape whose rank is greater than 1 is returned unchanged. Otherwise the
 * result is the 2-D shape [1, x_dim[0]].
 *
 * NOTE(review): a rank-0 shape also takes the promotion path and reads
 * x_dim[0]; confirm callers never pass one.
 */
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
  if (x_dim.size() <= 1) {
    return phi::make_ddim({1, x_dim[0]});
  }
  return x_dim;
}

/**
 * Promote a vector shape to a column-matrix shape.
 *
 * A shape whose rank is greater than 1 is returned unchanged. Otherwise the
 * result is the 2-D shape [y_dim[0], 1].
 *
 * NOTE(review): a rank-0 shape also takes the promotion path and reads
 * y_dim[0]; confirm callers never pass one.
 */
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
  if (y_dim.size() <= 1) {
    return phi::make_ddim({y_dim[0], 1});
  }
  return y_dim;
}

/**
 * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
 *
 * The shape would be [BatchSize, H, W] or [H, W].
 * If transposed, `H,W` will be swapped.
 */
static void ReshapeTensorIntoMatrixSequence(
80
    phi::DenseTensor* x, const phi::funcs::MatDescriptor& descriptor) {
81 82 83 84 85 86 87 88 89 90 91 92 93
  int64_t h, w;
  h = descriptor.height_;
  w = descriptor.width_;
  if (descriptor.trans_) {
    std::swap(w, h);
  }
  if (descriptor.batch_size_) {
    x->Resize({descriptor.batch_size_, h, w});
  } else {
    x->Resize({h, w});
  }
}

94 95 96
static void ReshapeXYOutIntoMatrixSequence(phi::DenseTensor* x,
                                           phi::DenseTensor* y,
                                           phi::DenseTensor* out,
97
                                           bool trans_x,
98 99 100
                                           bool trans_y) {
  auto x_dim = RowMatrixFromVector(x->dims());
  auto y_dim = ColumnMatrixFromVector(y->dims());
101 102
  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
103 104 105 106
  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
  } else {
    out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
107 108
                 mat_dim_x.height_,
                 mat_dim_y.width_});
109 110 111 112 113 114
  }

  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}

S
ShenLiang 已提交
115 116
}  // namespace operators
}  // namespace paddle