From 2e7365313b6cc5eb3eba74d23cd1558020bcdcdf Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Thu, 24 Mar 2022 14:27:39 +0800 Subject: [PATCH] [Phi] Migrate InferShape of multiplex, qr, tril_triu (#40102) * migrate infershape * fix tril_triu infershape error * fix qr_op infershape * add parse qr mode func * move order --- paddle/fluid/operators/multiplex_op.cc | 48 ++++----------------- paddle/fluid/operators/qr_op.cc | 48 ++++----------------- paddle/fluid/operators/tril_triu_op.cc | 24 ++++------- paddle/phi/infermeta/multiary.cc | 44 ++++++++++++++++++++ paddle/phi/infermeta/multiary.h | 4 ++ paddle/phi/infermeta/unary.cc | 53 ++++++++++++++++++++++++ paddle/phi/infermeta/unary.h | 10 +++++ paddle/phi/kernels/cpu/qr_kernel.cc | 24 +---------- paddle/phi/kernels/funcs/parse_qr_mode.h | 41 ++++++++++++++++++ 9 files changed, 179 insertions(+), 117 deletions(-) create mode 100644 paddle/phi/kernels/funcs/parse_qr_mode.h diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index 8771a6573c..4e6ad35e61 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -14,8 +14,13 @@ limitations under the License. */ #include #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { @@ -25,44 +30,6 @@ class MultiplexOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "Multiplex"); - PADDLE_ENFORCE_NE( - ctx->Inputs("X").empty(), true, - platform::errors::InvalidArgument("MultiInput(X) shouldn't be empty.")); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multiplex"); - auto ids_dim = ctx->GetInputDim("Ids"); - PADDLE_ENFORCE_EQ( - ids_dim.size(), 2, - platform::errors::PreconditionNotMet( - "The index tensor must be a vector with 2 dimensions")); - PADDLE_ENFORCE_EQ( - ids_dim[1], 1, - platform::errors::PreconditionNotMet( - "The index tensor must be a vector with batchSize x 1.")); - - auto ins_dims = ctx->GetInputsDim("X"); - auto num_ins = ins_dims.size(); - PADDLE_ENFORCE_GT(num_ins, 1, - platform::errors::InvalidArgument( - "multiplex operator should have more than " - "one candidate input tensors.")); - - auto in_dim = ins_dims[0]; - PADDLE_ENFORCE_GE( - in_dim.size(), 2, - platform::errors::InvalidArgument( - "The rank of candidate tensors must be not less than 2.")); - for (size_t i = 1; i < num_ins; i++) { - auto dim = ins_dims[i]; - PADDLE_ENFORCE_EQ( - in_dim, dim, - platform::errors::PreconditionNotMet( - "All the candidate tensors must have the same size.")); - } - ctx->SetOutputDim("Out", in_dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -164,8 +131,11 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor, + PD_INFER_META(phi::MultiplexInferMeta)); REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, ops::MultiplexGradMaker, - ops::MultiplexGradMaker); + ops::MultiplexGradMaker, + MultiplexInferShapeFunctor); REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp); diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc index 82fc9ef1b7..02d5e5f03f 100644 --- a/paddle/fluid/operators/qr_op.cc +++ b/paddle/fluid/operators/qr_op.cc @@ -21,6 +21,10 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -29,43 +33,6 @@ using DDim = framework::DDim; class QrOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr"); - OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr"); - OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr"); - - auto x_dims = ctx->GetInputDim("X"); - int x_rank = x_dims.size(); - PADDLE_ENFORCE_GE(x_dims.size(), 2, - platform::errors::InvalidArgument( - "the rank of input must greater than 2")); - bool compute_q; - bool reduced_mode; - int m = x_dims[x_rank - 2]; - int n = x_dims[x_rank - 1]; - int min_mn = std::min(m, n); - std::string mode = ctx->Attrs().Get("mode"); - std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); - - if (compute_q) { - int k = reduced_mode ? min_mn : m; - auto q_dims_vec = phi::vectorize(x_dims); - q_dims_vec[q_dims_vec.size() - 1] = k; - ctx->SetOutputDim("Q", phi::make_ddim(q_dims_vec)); - } else { - ctx->SetOutputDim("Q", phi::make_ddim({0})); - } - - int k = reduced_mode ? min_mn : m; - auto r_dims_vec = phi::vectorize(x_dims); - r_dims_vec[r_dims_vec.size() - 2] = k; - r_dims_vec[r_dims_vec.size() - 1] = n; - ctx->SetOutputDim("R", phi::make_ddim(r_dims_vec)); - - ctx->ShareLoD("X", /*->*/ "Q"); - ctx->ShareLoD("X", /*->*/ "R"); - } }; class QrOpMaker : public framework::OpProtoAndCheckerMaker { @@ -83,10 +50,8 @@ class QrOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("reduced"); AddComment(R"DOC( Qr Operator. - This operator is used to perform QR operation for batched matrics $X$. $$Q, R = qr(X)$$ - )DOC"); } }; @@ -138,10 +103,13 @@ class QrGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor, + PD_INFER_META(phi::QrInferMeta)); REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker, ops::QrGradMaker, - ops::QrGradMaker); + ops::QrGradMaker, + QrInferShapeFunctor); REGISTER_OPERATOR(qr_grad, ops::QrGradOp); diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc index c8010e8a12..b941fa3d03 100644 --- a/paddle/fluid/operators/tril_triu_op.cc +++ b/paddle/fluid/operators/tril_triu_op.cc @@ -13,29 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { class TrilTriuOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound("Input(X) of TrilTriuOp is not found.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of TrilTriuOp is not found.")); - const auto& x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GE(x_dims.size(), 2, - platform::errors::InvalidArgument( - "Input(X)'s rank must be at least 2 in TrilTriuOp.")); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker { @@ -100,7 +89,10 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor, + PD_INFER_META(phi::TrilTriuInferMeta)); REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ops::TrilTriuGradOpMaker, - ops::TrilTriuGradOpMaker); + ops::TrilTriuGradOpMaker, + TrilTriuInferShapeFunctor); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 4790fa863f..3aa4976062 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -832,6 +832,50 @@ void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { out->share_lod(*x.at(0)); } +void MultiplexInferMeta(const std::vector& ins, + const MetaTensor& ids, + MetaTensor* out) { + PADDLE_ENFORCE_NE( + ins.empty(), + true, + phi::errors::InvalidArgument("MultiInput(X) shouldn't be empty.")); + auto ids_dim = ids.dims(); + PADDLE_ENFORCE_EQ(ids_dim.size(), + 2, + phi::errors::PreconditionNotMet( + "The index tensor must be a vector with 2 dimensions")); + PADDLE_ENFORCE_EQ( + ids_dim[1], + 1, + phi::errors::PreconditionNotMet( + "The index tensor must be a vector with batchSize x 1.")); + + auto ins_dims = GetMetaTensorsDim(ins); + auto num_ins = ins_dims.size(); + PADDLE_ENFORCE_GT( + num_ins, + 1, + phi::errors::InvalidArgument("multiplex operator should have more than " + "one candidate input tensors.")); + + auto in_dim = ins_dims[0]; + PADDLE_ENFORCE_GE( + in_dim.size(), + 2, + phi::errors::InvalidArgument( + "The rank of candidate tensors must be not less than 2.")); + for (size_t i = 1; i < num_ins; i++) { + auto dim = ins_dims[i]; + PADDLE_ENFORCE_EQ( + in_dim, + dim, + phi::errors::PreconditionNotMet( + "All the candidate tensors must have the same size.")); + } + out->set_dims(in_dim); + out->set_dtype(ins[0]->dtype()); +} + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, paddle::optional rois_num, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 9088f20481..ddd7c132fb 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -152,6 +152,10 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x, void MultiDotInferMeta(const std::vector& x, MetaTensor* out); +void MultiplexInferMeta(const std::vector& ins, + const MetaTensor& ids, + MetaTensor* out); + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, paddle::optional rois_num, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 160e8ef56f..b76661d49b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" @@ -1129,6 +1130,44 @@ void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_layout(x.layout()); } +void QrInferMeta(const MetaTensor& x, + const std::string& mode, + MetaTensor* q, + MetaTensor* r) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE( + x_dims.size(), + 2, + phi::errors::InvalidArgument("the rank of input must greater than 2")); + bool compute_q; + bool reduced_mode; + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); + + if (compute_q) { + int k = reduced_mode ? min_mn : m; + auto q_dims_vec = phi::vectorize(x_dims); + q_dims_vec[q_dims_vec.size() - 1] = k; + q->set_dims(phi::make_ddim(q_dims_vec)); + } else { + q->set_dims(phi::make_ddim({0})); + } + + int k = reduced_mode ? min_mn : m; + auto r_dims_vec = phi::vectorize(x_dims); + r_dims_vec[r_dims_vec.size() - 2] = k; + r_dims_vec[r_dims_vec.size() - 1] = n; + r->set_dims(phi::make_ddim(r_dims_vec)); + + q->share_lod(x); + r->share_lod(x); + q->set_dtype(x.dtype()); + r->set_dtype(x.dtype()); +} + DDim ReduceInferDim(const MetaTensor& x, const std::vector& axis, bool keep_dim, @@ -1847,6 +1886,20 @@ void UnbindInferMeta(const MetaTensor& x, } } +void TrilTriuInferMeta(const MetaTensor& x, + int diagonal, + bool lower, + MetaTensor* out) { + const auto& x_dims = x.dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(X)'s rank must be at least 2 in TrilTriuOp.")); + out->set_dims(x.dims()); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) { out->share_meta(x); } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6187c49de1..8e254965ab 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -180,6 +180,11 @@ void PoolInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void QrInferMeta(const MetaTensor& x, + const std::string& mode, + MetaTensor* q, + MetaTensor* r); + void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out); void ReduceInferMeta(const MetaTensor& x, @@ -282,6 +287,11 @@ void TransposeGradInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out); +void TrilTriuInferMeta(const MetaTensor& x, + int diagonal, + bool lower, + MetaTensor* out); + void UnbindInferMeta(const MetaTensor& x, int axis, std::vector* outs); diff --git a/paddle/phi/kernels/cpu/qr_kernel.cc b/paddle/phi/kernels/cpu/qr_kernel.cc index e2e3256744..b0e82cedb6 100644 --- a/paddle/phi/kernels/cpu/qr_kernel.cc +++ b/paddle/phi/kernels/cpu/qr_kernel.cc @@ -19,30 +19,10 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" namespace phi { -static inline std::tuple ParseQrMode(const std::string& mode) { - bool compute_q; - bool reduced; - if (mode == "reduced") { - compute_q = true; - reduced = true; - } else if (mode == "complete") { - compute_q = true; - reduced = false; - } else if (mode == "r") { - compute_q = false; - reduced = true; - } else { - PADDLE_THROW(errors::InvalidArgument( - "QR received unrecognized mode '%s'" - " but expected one of 'reduced' (default), 'r', or 'complete'", - mode)); - } - return std::make_tuple(compute_q, reduced); -} - template void QrKernel(const Context& ctx, const DenseTensor& x, @@ -51,7 +31,7 @@ void QrKernel(const Context& ctx, DenseTensor* r) { bool compute_q; bool reduced_mode; - std::tie(compute_q, reduced_mode) = ParseQrMode(mode); + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); auto numel = x.numel(); PADDLE_ENFORCE_GT( numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); diff --git a/paddle/phi/kernels/funcs/parse_qr_mode.h b/paddle/phi/kernels/funcs/parse_qr_mode.h new file mode 100644 index 0000000000..adf64759d3 --- /dev/null +++ b/paddle/phi/kernels/funcs/parse_qr_mode.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace funcs { + +static inline std::tuple ParseQrMode(const std::string& mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; + } else { + PADDLE_THROW(errors::InvalidArgument( + "QR received unrecognized mode '%s'" + " but expected one of 'reduced' (default), 'r', or 'complete'", + mode)); + } + return std::make_tuple(compute_q, reduced); +} +} // namespace funcs +} // namespace phi -- GitLab