matmul_v2_op.h 3.7 KB
Newer Older
S
ShenLiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
19
#include <utility>
S
ShenLiang 已提交
20
#include <vector>
21

S
ShenLiang 已提交
22 23 24 25
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dot_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
26 27
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
S
ShenLiang 已提交
28

29 30 31 32
// only can include the headers in paddle/phi/api dirs
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
Z
zyfncg 已提交
33

34
#if defined(__NVCC__) || defined(__HIPCC__)
35
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
S
ShenLiang 已提交
36 37 38 39 40
#endif

namespace paddle {
namespace operators {

41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Collapse the two leading dimensions of a rank-3 tensor: a tensor of shape
// P x M x N is returned with shape (P * M) x N.
// For any other rank the returned tensor keeps the dimensions of `input`.
// The resize is applied to a copy of the tensor object, never to `input`.
static framework::Tensor FoldInitDims(const framework::Tensor& input) {
  framework::Tensor folded = input;
  const auto dims = input.dims();
  // Guard clause: only rank-3 inputs are folded.
  if (dims.size() != 3) {
    return folded;
  }
  folded.Resize({dims[0] * dims[1], dims[2]});
  return folded;
}

/**
 * Get row matrix shape from a vector shape.
 *
 * If the rank of x_dim > 1, the original x_dim is returned unchanged.
 * Otherwise the result is the 2-D shape [1, x_dim[0]].
 *
 * NOTE(review): x_dim[0] is read whenever the rank is <= 1, so a rank-0
 * shape is presumably never passed in here -- confirm against callers.
 */
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
  if (x_dim.size() > 1) {
    return x_dim;
  }
  return phi::make_ddim({1, x_dim[0]});
}

/**
 * Get column matrix shape from a vector shape.
 *
 * If the rank of y_dim > 1, the original y_dim is returned unchanged.
 * Otherwise the result is the 2-D shape [y_dim[0], 1].
 *
 * NOTE(review): y_dim[0] is read whenever the rank is <= 1, so a rank-0
 * shape is presumably never passed in here -- confirm against callers.
 */
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
  if (y_dim.size() > 1) {
    return y_dim;
  }
  return phi::make_ddim({y_dim[0], 1});
}

/**
 * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
 *
 * The shape would be [BatchSize, H, W] or [H, W].
 * If transposed, `H,W` will be swapped.
 */
static void ReshapeTensorIntoMatrixSequence(
81
    framework::Tensor* x, const phi::funcs::MatDescriptor& descriptor) {
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
  int64_t h, w;
  h = descriptor.height_;
  w = descriptor.width_;
  if (descriptor.trans_) {
    std::swap(w, h);
  }
  if (descriptor.batch_size_) {
    x->Resize({descriptor.batch_size_, h, w});
  } else {
    x->Resize({h, w});
  }
}

static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
                                           framework::Tensor* y,
97 98
                                           framework::Tensor* out,
                                           bool trans_x,
99 100 101
                                           bool trans_y) {
  auto x_dim = RowMatrixFromVector(x->dims());
  auto y_dim = ColumnMatrixFromVector(y->dims());
102 103
  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
  auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
104 105 106 107
  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
  } else {
    out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
108 109
                 mat_dim_x.height_,
                 mat_dim_y.width_});
110 111 112 113 114 115
  }

  ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
  ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}

S
ShenLiang 已提交
116 117
}  // namespace operators
}  // namespace paddle