From a2aa0087ba1face575e34badf1f5c9fd637fc260 Mon Sep 17 00:00:00 2001
From: Sanbu <96160062+sanbuphy@users.noreply.github.com>
Date: Mon, 17 Apr 2023 17:14:31 +0800
Subject: [PATCH] Support static graph code-gen for matrix_rank (#52659)

---
 paddle/fluid/operators/matrix_rank_op.cc | 123 -----------------------
 paddle/phi/api/yaml/op_compat.yaml       |   7 ++
 paddle/phi/api/yaml/static_ops.yaml      |  12 +++
 paddle/phi/infermeta/binary.cc           |  14 ++-
 paddle/phi/infermeta/binary.h            |   6 ++
 paddle/phi/ops/compat/matrix_rank_sig.cc |   6 ++
 6 files changed, 44 insertions(+), 124 deletions(-)
 delete mode 100644 paddle/fluid/operators/matrix_rank_op.cc

diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc
deleted file mode 100644
index 16ca2cf09ec..00000000000
--- a/paddle/fluid/operators/matrix_rank_op.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/operators/svd_helper.h"
-#include "paddle/phi/kernels/funcs/compare_functors.h"
-
-namespace paddle {
-namespace operators {
-using DDim = framework::DDim;
-
-namespace detail {
-static DDim CheckAndGetOutputDim(const DDim& dim_x) {
-  auto x_vec = phi::vectorize(dim_x);
-  if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
-  }
-  x_vec.erase(x_vec.end() - 2, x_vec.end());
-  return phi::make_ddim(x_vec);
-}
-}  // namespace detail
-
-class MatrixRankOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MatrixRank");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixRank");
-    auto dim_x = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_GE(dim_x.size(),
-                      2,
-                      platform::errors::InvalidArgument(
-                          "The dims of input must be greater than 2"));
-
-    bool hermitian = ctx->Attrs().Get<bool>("hermitian");
-    if (hermitian) {
-      int rows = dim_x[dim_x.size() - 2];
-      int cols = dim_x[dim_x.size() - 1];
-      PADDLE_ENFORCE_EQ(rows,
-                        cols,
-                        platform::errors::InvalidArgument(
-                            "if hermitian == true, matrix should be n*n"));
-    }
-
-    DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
-    if (ctx->HasInput("TolTensor")) {
-      auto dim_tol = ctx->GetInputDim("TolTensor");
-      if (dim_x_batch == dim_tol) {
-        ctx->SetOutputDim("Out", dim_x_batch);
-      } else {
-        int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
-        int axis = std::abs(dim_x_batch.size() - dim_tol.size());
-        std::vector<int> x_batch_dims_array(max_dim);
-        std::vector<int> tol_dims_array(max_dim);
-        std::vector<int> out_dims_array(max_dim);
-        phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
-                                           dim_tol,
-                                           x_batch_dims_array.data(),
-                                           tol_dims_array.data(),
-                                           out_dims_array.data(),
-                                           max_dim,
-                                           axis);
-        ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
-      }
-    } else {
-      ctx->SetOutputDim("Out", dim_x_batch);
-    }
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return phi::KernelKey(data_type, ctx.GetPlace());
-  }
-};
-
-class MatrixRankOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "(Tensor), The input tensor of matrix_rank op.");
-    AddInput("TolTensor",
-             "(optional) Tol tensor, shape is same as X batch or can broadcast "
-             "with X batch.")
-        .AsDispensable();
-    AddOutput("Out", "(Tensor), The output tensor of matrix_rank op.");
-    AddAttr<float>("tol", "(float, optional). tol").SetDefault(0.0f);
-    AddAttr<bool>("use_default_tol",
-                  "represent whether user input TolTensor/tol, if input "
-                  "TolTensor/tol use_default_tol=true, otherwise "
-                  "use_default_tol=false")
-        .SetDefault(true);
-    AddAttr<bool>("hermitian", "(bool, optional). whether is hermitian matrix")
-        .SetDefault(false);
-    AddComment(R"DOC(MatrixRank Operator.
-    This operator is used to perform MatrixRank operation for batched matrics.
-    $$out = matrix_rank(X, tol, hermitian)$$
-    )DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OPERATOR(matrix_rank, ops::MatrixRankOp, ops::MatrixRankOpMaker);
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 38e39b6b25b..077d746c22f 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1391,6 +1391,13 @@
   outputs :
     out : Out
 
+- op : matrix_rank
+  inputs :
+    {x : X, tol_tensor : TolTensor}
+  outputs :
+    out : Out
+  manual_signature : [matrix_rank]
+
 - op : max (reduce_max)
   backward : max_grad (reduce_max_grad)
   inputs:
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 46c107ad2d6..9a90addfd0d 100644
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -217,6 +217,18 @@
   data_transform :
     skip_transform : start, stop, number
 
+- op : matrix_rank
+  args : (Tensor x, Tensor tol_tensor, float tol=0.0f, bool hermitian=false, bool use_default_tol=true)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankStaticInferMeta
+    param : [x, tol_tensor, use_default_tol, hermitian]
+  optional : tol_tensor
+  kernel :
+    func : matrix_rank {dense -> dense},
+           matrix_rank_tol {dense, dense -> dense}
+    data_type : x
+
 - op : not_equal
   args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
   output : Tensor(out)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 054563559fb..85958f285e0 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -18,12 +18,12 @@ limitations under the License. */
 #include <algorithm>
 
 #include "glog/logging.h"
-
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
@@ -2199,6 +2199,18 @@ void MatrixNMSInferMeta(const MetaTensor& bboxes,
   }
 }
 
+void MatrixRankStaticInferMeta(const MetaTensor& x,
+                               const MetaTensor& atol_tensor,
+                               bool use_default_tol,
+                               bool hermitian,
+                               MetaTensor* out) {
+  if (atol_tensor) {
+    MatrixRankTolInferMeta(x, atol_tensor, use_default_tol, hermitian, out);
+  } else {
+    MatrixRankInferMeta(x, hermitian, use_default_tol, out);
+  }
+}
+
 void MatrixRankTolInferMeta(const MetaTensor& x,
                             const MetaTensor& atol_tensor,
                             bool use_default_tol,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 90e668b8c05..c071b3319bb 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -353,6 +353,12 @@ void MatrixNMSInferMeta(const MetaTensor& bboxes,
                         MetaTensor* roisnum,
                         MetaConfig config = MetaConfig());
 
+void MatrixRankStaticInferMeta(const MetaTensor& x,
+                               const MetaTensor& atol_tensor,
+                               bool use_default_tol,
+                               bool hermitian,
+                               MetaTensor* out);
+
 void MatrixRankTolInferMeta(const MetaTensor& x,
                             const MetaTensor& atol_tensor,
                             bool use_default_tol,
diff --git a/paddle/phi/ops/compat/matrix_rank_sig.cc b/paddle/phi/ops/compat/matrix_rank_sig.cc
index 3a9a0400627..9ea8fbcbe50 100644
--- a/paddle/phi/ops/compat/matrix_rank_sig.cc
+++ b/paddle/phi/ops/compat/matrix_rank_sig.cc
@@ -18,6 +18,12 @@ namespace phi {
 
 // we have to return every specific KernelSignature for infrt now
 KernelSignature MatrixRankOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsForInferShape()) {
+    return KernelSignature("matrix_rank_tol",
+                           {"X", "TolTensor"},
+                           {"use_default_tol", "hermitian"},
+                           {"Out"});
+  }
  if (ctx.HasInput("TolTensor")) {
    return KernelSignature("matrix_rank_tol",
                           {"X", "TolTensor"},
--
GitLab
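
Usage sketch (illustrative, not part of the patch): with this change, the static-graph
matrix_rank op is generated from static_ops.yaml, and compile-time shape inference runs
through MatrixRankStaticInferMeta, which dispatches to MatrixRankTolInferMeta when a
TolTensor is supplied and to MatrixRankInferMeta otherwise. A minimal static-graph
program that exercises the migrated path, assuming a Paddle 2.x build containing this
commit (the shapes below are arbitrary example values):

    import numpy as np
    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Building the program invokes the code-generated InferShape; the
        # output drops the trailing two matrix dims, keeping batch dims only.
        x = paddle.static.data(name="x", shape=[3, 4, 4], dtype="float32")
        out = paddle.linalg.matrix_rank(x, hermitian=False)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    (rank,) = exe.run(main_prog,
                      feed={"x": np.random.rand(3, 4, 4).astype("float32")},
                      fetch_list=[out])
    print(rank.shape)  # (3,): one rank per matrix in the batch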