From 090396368c80360fc33d09dfb1df7492f7dfb544 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 1 Mar 2022 19:23:04 +0800 Subject: [PATCH] [Phi]rm reduce infershape (#39820) * modify infershape utils and rm reduce infershape * merge develop * fix infermete bug * add IsForInferShape func in ArgumentMappingContext * add reduce_mean infermeta * modify annotation * add default dims --- paddle/fluid/framework/infershape_utils.cc | 6 +- paddle/fluid/framework/operator.h | 2 + .../operators/reduce_ops/reduce_mean_op.cc | 10 +++- .../operators/reduce_ops/reduce_sum_op.cc | 10 +++- .../dialect/phi/pass/proto_arg_map_context.h | 2 + paddle/phi/core/compat/arg_map_context.h | 4 ++ paddle/phi/infermeta/unary.cc | 60 +++++++++++++++---- paddle/phi/infermeta/unary.h | 15 +++-- paddle/phi/kernels/math_kernel.h | 2 +- paddle/phi/ops/compat/reduce_sig.cc | 34 +++++++---- paddle/phi/tests/ops/test_op_signature.h | 2 + python/paddle/utils/code_gen/api.yaml | 2 +- 12 files changed, 117 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index d9287b9a624..57fb68e8042 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { return var_types[0] == proto::VarType::SELECTED_ROWS; } + bool IsForInferShape() const override { return true; } + private: const InferShapeContext& ctx_; }; @@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor { } } else { auto* var = BOOST_GET_CONST(VarDesc*, var_); - return phi::make_ddim(var->GetShape()); + + return var->GetShape().empty() ? phi::make_ddim({0UL}) + : phi::make_ddim(var->GetShape()); } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 16718a31651..e33d4feb82a 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { return ctx_.OutputVar(name)->IsType(); } + bool IsForInferShape() const override { return false; } + private: const ExecutionContext& ctx_; }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index e80df5f95bb..6157a3a925d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -18,6 +18,10 @@ #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { @@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_mean"; } }; +DELCARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor, + PT_INFER_META(phi::MeanRawInferMeta)); + REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__, ops::ReduceMeanOpGradMaker, - ops::ReduceMeanOpGradMaker); + ops::ReduceMeanOpGradMaker, + ReduceMeanInferShapeFunctor); REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp, ops::ReduceMeanDoubleGradDescMaker, ops::ReduceMeanDoubleGradOpBaseMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index bdab14a18a0..8ef0712dc7a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -16,6 +16,10 @@ #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace framework { class OpDesc; @@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_sum"; } }; +DELCARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor, + PT_INFER_META(phi::ReduceInferMetaBase)); + REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, ops::ReduceSumVarTypeInference, ops::ReduceSumOpGradMaker, - ops::ReduceSumOpGradMaker); + ops::ReduceSumOpGradMaker, + ReduceSumInferShapeFunctor); REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, ops::ReduceSumDoubleOpGradMaker, ops::ReduceSumDoubleOpGradMaker, diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h index 843b19d217f..ca8a22a7e75 100644 --- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h +++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h @@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext { bool IsDenseTensorOutput(const std::string& name) const override; bool IsSelectedRowsOutput(const std::string& name) const override; + bool IsForInferShape() const override { return false; } + private: mlir::Operation* op_; const std::unordered_map& input_map_; diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index af29b3bab5c..f625d57df2e 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -91,6 +91,10 @@ class ArgumentMappingContext { virtual bool IsDenseTensorOutput(const std::string& name) const = 0; virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; + + // use this function to mark it comes from InferShapeArgumentMappingContext + // and will be used in infershape + virtual bool IsForInferShape() const = 0; }; } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 4696187bd23..983e0162264 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -375,7 +375,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ReshapeInferMeta(x, shape, out, config); } -/* Why not use ReduceInferMeta directly? +/* Why not use ReduceInferMetaBase directly? Because we need make InferMetaFunction's args follow the design of api.yaml */ void SumInferMeta(const MetaTensor& x, @@ -383,22 +383,53 @@ void SumInferMeta(const MetaTensor& x, DataType dtype, bool keep_dim, MetaTensor* out) { - ReduceInferMetaBase(x, axis, keep_dim, dtype, out); + bool reduce_all = false; + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out); } void ReduceInferMetaBase(const MetaTensor& x, const std::vector& axis, bool keep_dim, + bool reduce_all, DataType dtype, MetaTensor* out) { - bool reduce_all = true; - std::set dims_set(axis.begin(), axis.end()); + auto x_rank = x.dims().size(); + + std::vector formated_axis = axis; + for (size_t i = 0; i < axis.size(); ++i) { + PADDLE_ENFORCE_LT(axis[i], + x_rank, + errors::InvalidArgument( + "The reduce dim index %d should be in the " + "range [-dimension(X), dimension(X)] " + "which dimesion = %d. But received dim index = %d.", + i, + x_rank, + axis[i])); + PADDLE_ENFORCE_GE(axis[i], + -x_rank, + errors::InvalidArgument( + "The reduce dim index %d should be in the " + "range [-dimension(X), dimension(X)] " + "which dimesion = %d. But received dim index = %d.", + i, + x_rank, + axis[i])); + + if (axis[i] < 0) { + formated_axis[i] = axis[i] + x_rank; + } + } + + bool full_dim = true; + std::set dims_set(formated_axis.begin(), formated_axis.end()); for (int64_t i = 0; i < x.dims().size(); ++i) { if (dims_set.find(i) == dims_set.end()) { - reduce_all = false; + full_dim = false; break; } } + reduce_all = reduce_all || full_dim; std::vector out_dim_vector; if (keep_dim) { @@ -441,11 +472,20 @@ void ReduceInferMetaBase(const MetaTensor& x, out->set_layout(x.layout()); } -void ReduceInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out) { - ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out); +void MeanRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out) { + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); +} + +void MeanInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out) { + bool reduce_all = false; + ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); } void TransferLayoutInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index b3929b9d2b4..a2d779e0f70 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -86,13 +86,20 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x, const std::vector& axis, bool keep_dim, + bool reduce_all, DataType dtype, MetaTensor* out); -void ReduceInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - MetaTensor* out); +void MeanRawInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out); + +void MeanInferMeta(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + MetaTensor* out); void SumInferMeta(const MetaTensor& x, const std::vector& axis, diff --git a/paddle/phi/kernels/math_kernel.h b/paddle/phi/kernels/math_kernel.h index c6036f4a042..342393d79bd 100644 --- a/paddle/phi/kernels/math_kernel.h +++ b/paddle/phi/kernels/math_kernel.h @@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, bool keep_dim) { auto dense_out = phi::Empty(dev_ctx); MetaTensor meta_out(&dense_out); - ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out); + ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out); MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); return dense_out; } diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index 74704671f8b..6395486ed2b 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -17,28 +17,36 @@ limitations under the License. */ namespace phi { KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { - bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); if (ctx.IsDenseTensorInput("X")) { - if (!reduce_all) { - return KernelSignature( - "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"}); + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in + // InferShape, so we must return the "sum_raw" KernelSignature. + // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with + // the "sum_raw" KernelSignature + if (ctx.IsForInferShape() || reduce_all) { + return KernelSignature("sum_raw", + {"X"}, + {"dim", "keep_dim", "reduce_all", "out_dtype"}, + {"Out"}); } - return KernelSignature("sum_raw", - {"X"}, - {"dim", "keep_dim", "reduce_all", "out_dtype"}, - {"Out"}); + return KernelSignature( + "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"}); } return KernelSignature("unregistered", {}, {}, {}); } KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { - bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); if (ctx.IsDenseTensorInput("X")) { - if (!reduce_all) { - return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"}); + bool reduce_all = paddle::any_cast(ctx.Attr("reduce_all")); + // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in + // InferShape, so we must return the "mean_raw" KernelSignature. + // And the InferMeta function(i.e. MeanRawInferMeta) is accordance with the + // "mean_raw" KernelSignature + if (ctx.IsForInferShape() || reduce_all) { + return KernelSignature( + "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); } - return KernelSignature( - "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"}); } return KernelSignature("unregistered", {}, {}, {}); } diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h index fcd2d397fa2..06048f33d94 100644 --- a/paddle/phi/tests/ops/test_op_signature.h +++ b/paddle/phi/tests/ops/test_op_signature.h @@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext { return selected_rows_outputs.count(name) > 0; } + bool IsForInferShape() const override { return false; } + private: const std::unordered_set dense_tensor_inputs; const std::unordered_set selected_rows_inputs; diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 7ea8493b67f..45a6aae5e6d 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -124,7 +124,7 @@ args : (Tensor x, int64_t[] axis={}, bool keep_dim=false) output : Tensor infer_meta : - func : ReduceInferMeta + func : MeanInferMeta kernel : func : mean -- GitLab