/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/matmul_op.h"
#include <algorithm>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class MatMulOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("X"),
                   "Input(X) of MatMulOp should not be null.");
    PADDLE_ENFORCE(context->HasInput("Y"),
                   "Input(Y) of MatMulOp should not be null.");
    PADDLE_ENFORCE(context->HasOutput("Out"),
                   "Output(Out) of MatMulOp should not be null.");

    auto dim_x = context->GetInputDim("X");
    auto dim_y = context->GetInputDim("Y");

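    // Build a 2-D matrix descriptor for each input: leading dimensions are
    // folded into a batch size, and the transpose flags are applied up front.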
    auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0,
                                     context->Attrs().Get<bool>("transpose_X"));
    auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0,
                                     context->Attrs().Get<bool>("transpose_Y"));

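    // The contracted dimensions must match, and the batch sizes must match
    // unless one side has no batch dimension (it is then reused across the
    // other side's batch).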
    PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
    PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
                   mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
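    // Derive the output shape: batch dimensions come from whichever input has
    // them; the trailing matrix shape is always [height of X, width of Y].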
    std::vector<int64_t> dim_out;
    if (mat_dim_x.batch_size_ != 0) {
      dim_out = framework::vectorize(dim_x);
      dim_out[dim_out.size() - 2] = mat_dim_x.height_;
      dim_out[dim_out.size() - 1] = mat_dim_y.width_;
    } else if (mat_dim_y.batch_size_ != 0) {
      dim_out = framework::vectorize(dim_y);
      dim_out[dim_out.size() - 2] = mat_dim_x.height_;
      dim_out[dim_out.size() - 1] = mat_dim_y.width_;
    } else {
      dim_out = {mat_dim_x.height_, mat_dim_y.width_};
    }

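    // If X was a rank-1 vector, squeeze out the row dimension of size 1 that
    // was implicitly prepended for the multiplication.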
    if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
      std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]);
      dim_out.resize(dim_out.size() - 1);
    }

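    // Likewise, if Y was a rank-1 vector, squeeze out its implicit trailing
    // column dimension of size 1.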
    if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) {
      dim_out.resize(dim_out.size() - 1);
    }

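    // A vector-vector product reduces to a scalar, represented as shape [1].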
    if (dim_out.empty()) {
      dim_out = {1};
    }
    context->SetOutputDim("Out", framework::make_ddim(dim_out));
    context->ShareLoD("X", /*->*/ "Out");
  }
};

class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The first input of MatMul op");
    AddInput("Y", "The second input of MatMul op");
    AddOutput("Out", "The output of MatMul op");
    AddAttr<bool>("transpose_X",
                  R"DOC(If true, use the transpose of `X`.
        )DOC")
        .SetDefault(false);
    AddAttr<bool>("transpose_Y",
                  R"DOC(If true, use the transpose of `Y`.
        )DOC")
        .SetDefault(false);
    AddComment(R"DOC(
MatMul Operator.

This operator performs (batched) matrix multiplication
over the last two dimensions of the input tensors `X` and `Y`.

If a transpose flag is specified, the last two dimensions of the
tensor are transposed. If the tensor is rank-1 with shape [D], then
`X` is treated as [1, D] in non-transposed form and as [D, 1] in
transposed form, whereas `Y` is the opposite: it is treated as
[D, 1] in non-transposed form and as [1, D] in transposed form.

Examples without transpose:
- X: [K], Y: [K] => Out: [1]
- X: [K], Y: [K, N] => Out: [N]
- X: [B, M, K], Y: [K] => Out: [B, M]
- X: [M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, ..., M, K], Y: [B, ..., K, N] => Out: [B, ..., M, N]

The behavior is designed to be similar to the `numpy.matmul` function.
The differences are:
- When the rank of each input is less than or equal to 3, the behavior
  is similar to that of `numpy.matmul`.
- When the rank of the inputs is greater than 3, the ranks of `X` and
  `Y` must be equal, and their first `rank - 2` dimensions must be equal.
- We add `transpose_X` and `transpose_Y` flags.

Both inputs `X` and `Y` can carry the LoD (Level of Detail) information,
or not. But the output only shares the LoD information with input `X`.

)DOC");
  }
};

class MatMulOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = context->GetInputDim("X");
    auto y_dims = context->GetInputDim("Y");

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");

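    // Each requested gradient has exactly the same shape as its input.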
    if (context->HasOutput(x_grad_name)) {
      context->SetOutputDim(x_grad_name, x_dims);
    }
    if (context->HasOutput(y_grad_name)) {
      context->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
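// Register the forward and backward operators; the forward operator uses the
// default gradient-op maker. Only float CPU kernels are registered here.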
REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL(
    matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    matmul_grad,
    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);