From 1593c7ca2ad4c757da69581cb99f5d6250df2263 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 11 Mar 2022 12:55:54 +0800 Subject: [PATCH] [Phi] Fix infershape if encounter TensorList and Attr("XXX") (#40420) * [Phi] Fix infershape if encounter TensorList and Attr("XXX") * add InferShapeArgumentMappingContext --- paddle/fluid/framework/infershape_utils.cc | 2 ++ paddle/phi/core/compat/arg_map_context.h | 7 +++++++ paddle/phi/ops/compat/gaussian_random_sig.cc | 19 ++++++++++++++----- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 91ef59575c..29c7f5d0ce 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -90,6 +90,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { bool IsForInferShape() const override { return true; } + bool IsRuntime() const override { return ctx_.IsRuntime(); } + private: const InferShapeContext& ctx_; }; diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 688a0e54a0..25b80279ec 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -96,6 +96,13 @@ class ArgumentMappingContext { // use this function to mark it comes from InferShapeArgumentMappingContext // and will be used in infershape virtual bool IsForInferShape() const = 0; + + // NOTE(paddle-dev): [ Why do we export this interface? ] + // In old Fluid framework, some operators' Attribute can be a Tensor or + // TensorList. In this case, the InferShape logic will be different + // under CompileTime and RuntimeTime. So we export this interface to + // handle it conveniently. See "gaussian_random_sig.cc" for details. + virtual bool IsRuntime() const { return true; } }; } // namespace phi diff --git a/paddle/phi/ops/compat/gaussian_random_sig.cc b/paddle/phi/ops/compat/gaussian_random_sig.cc index cddcb80ebe..2f2b157e4c 100644 --- a/paddle/phi/ops/compat/gaussian_random_sig.cc +++ b/paddle/phi/ops/compat/gaussian_random_sig.cc @@ -18,14 +18,23 @@ namespace phi { KernelSignature GaussianRandomOpArgumentMapping( const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); if (ctx.InputSize("ShapeTensorList") > 0) { - return KernelSignature("gaussian_random", - {}, - {"ShapeTensorList", "mean", "std", "seed", "dtype"}, - {"Out"}); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature("gaussian_random", + {}, + {"shape", "mean", "std", "seed", "dtype"}, + {"Out"}); + } else { + return KernelSignature( + "gaussian_random", + {}, + {"ShapeTensorList", "mean", "std", "seed", "dtype"}, + {"Out"}); + } } - const auto& shape = paddle::any_cast>(ctx.Attr("shape")); if (ctx.HasInput("ShapeTensor") && shape.empty()) { return KernelSignature("gaussian_random", {}, -- GitLab