// matrix_rank_op.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {
using DDim = framework::DDim;

namespace detail {
31
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
32
  auto x_vec = phi::vectorize(dim_x);
33
  if (x_vec.size() == 2) {
34
    return phi::make_ddim({1});
35 36
  }
  x_vec.erase(x_vec.end() - 2, x_vec.end());
37
  return phi::make_ddim(x_vec);
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
}
}  // namespace detail

// Operator definition for matrix_rank: validates the inputs, infers the shape
// of "Out", and selects the kernel type from the data type of "X".
class MatrixRankeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers the shape of "Out".
  //
  // - "X" must have at least two dimensions; its two trailing dimensions are
  //   read as (rows, cols) and the remaining leading ones form the batch
  //   shape (see detail::CheckAndGetOutputDim).
  // - When the "hermitian" attribute is set, rows must equal cols.
  // - Without "TolTensor", "Out" takes the batch shape. With "TolTensor",
  //   "Out" takes the batch shape if both shapes are equal, otherwise the
  //   shape produced by broadcasting the batch shape against "TolTensor".
  //
  // Raises InvalidArgument when the rank or the hermitian check fails.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MatrixRank");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixRank");
    auto dim_x = ctx->GetInputDim("X");
    // The check accepts rank == 2, so the message must say "or equal to".
    PADDLE_ENFORCE_GE(dim_x.size(), 2,
                      platform::errors::InvalidArgument(
                          "The dims of input must be greater than or equal "
                          "to 2, but received %d.",
                          dim_x.size()));

    bool hermitian = ctx->Attrs().Get<bool>("hermitian");
    if (hermitian) {
      // Keep the native element type of DDim instead of narrowing to int, so
      // the comparison is made on the untruncated extents.
      const auto rows = dim_x[dim_x.size() - 2];
      const auto cols = dim_x[dim_x.size() - 1];
      PADDLE_ENFORCE_EQ(rows, cols,
                        platform::errors::InvalidArgument(
                            "if hermitian == true, matrix should be n*n, "
                            "but received %d*%d.",
                            rows, cols));
    }

    DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
    if (ctx->HasInput("TolTensor")) {
      auto dim_tol = ctx->GetInputDim("TolTensor");
      if (dim_x_batch == dim_tol) {
        ctx->SetOutputDim("Out", dim_x_batch);
      } else {
        // Shapes differ: compute the broadcast shape. `axis` is the rank
        // difference between the two shapes.
        int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
        int axis = std::abs(dim_x_batch.size() - dim_tol.size());
        std::vector<int> x_batch_dims_array(max_dim);
        std::vector<int> tol_dims_array(max_dim);
        std::vector<int> out_dims_array(max_dim);
        phi::funcs::GetBroadcastDimsArrays(
            dim_x_batch, dim_tol, x_batch_dims_array.data(),
            tol_dims_array.data(), out_dims_array.data(), max_dim, axis);
        ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
      }
    } else {
      ctx->SetOutputDim("Out", dim_x_batch);
    }
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // Builds the kernel type from the data type of "X" and the execution
  // place, using the plain library and any data layout.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

// Declares the inputs, output, attributes and documentation of the
// matrix_rank operator.
class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers:
  //   inputs  - "X" (required), "TolTensor" (dispensable)
  //   output  - "Out"
  //   attrs   - "tol" (float, default 0.0f),
  //             "use_default_tol" (bool, default true),
  //             "hermitian" (bool, default false)
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of matrix_rank op.");
    AddInput("TolTensor",
             "(optional) Tol tensor, shape is same as X batch or can broadcast "
             "with X batch.")
        .AsDispensable();
    AddOutput("Out", "(Tensor), The output tensor of matrix_rank op.");
    AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f);
    // The default is true, i.e. the default tolerance is used unless the
    // user supplies one; the description is worded to match.
    AddAttr<bool>("use_default_tol",
                  "represent whether to use the default tol: if user input "
                  "TolTensor/tol use_default_tol=false, otherwise "
                  "use_default_tol=true")
        .SetDefault(true);
    AddAttr<bool>("hermitian", "(bool, optional). whether is hermitian matrix")
        .SetDefault(false);
    AddComment(R"DOC(MatrixRank Operator.
    This operator is used to perform MatrixRank operation for batched matrices.
    $$out = matrix_rank(X, tol, hermitian)$$
    )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

// Short alias for the operator namespace, used by the registration below.
namespace ops = paddle::operators;

// Registers the "matrix_rank" operator with its operator class and its proto
// maker. No other arguments (such as a gradient op maker) are passed here.
REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker);