diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 91ef59575c3aa2a737f32c0ca90a7cbb2b3f3744..29c7f5d0ce73cbf1af18e6f5869d59d2200917ad 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -90,6 +90,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { bool IsForInferShape() const override { return true; } + bool IsRuntime() const override { return ctx_.IsRuntime(); } + private: const InferShapeContext& ctx_; }; diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 688a0e54a0cf4f0f041704b03c5d256a7c17d1ec..25b80279ecf10619d97b8800b24ab5353c79745d 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -96,6 +96,13 @@ class ArgumentMappingContext { // use this function to mark it comes from InferShapeArgumentMappingContext // and will be used in infershape virtual bool IsForInferShape() const = 0; + + // NOTE(paddle-dev): [ Why do we export this interface? ] + // In old Fluid framework, some operators' Attribute can be a Tensor or + // TensorList. In this case, the InferShape logic will be different + // under CompileTime and RuntimeTime. So we export this interface to + // handle it conveniently. See "gaussian_random_sig.cc" for details. + virtual bool IsRuntime() const { return true; } }; } // namespace phi diff --git a/paddle/phi/ops/compat/gaussian_random_sig.cc b/paddle/phi/ops/compat/gaussian_random_sig.cc index cddcb80ebea3ddcae345789497ca8006301f7a6e..2f2b157e4c0f950c75ce2d8a66455127f51752fa 100644 --- a/paddle/phi/ops/compat/gaussian_random_sig.cc +++ b/paddle/phi/ops/compat/gaussian_random_sig.cc @@ -18,14 +18,23 @@ namespace phi { KernelSignature GaussianRandomOpArgumentMapping( const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); if (ctx.InputSize("ShapeTensorList") > 0) { - return KernelSignature("gaussian_random", - {}, - {"ShapeTensorList", "mean", "std", "seed", "dtype"}, - {"Out"}); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature("gaussian_random", + {}, + {"shape", "mean", "std", "seed", "dtype"}, + {"Out"}); + } else { + return KernelSignature( + "gaussian_random", + {}, + {"ShapeTensorList", "mean", "std", "seed", "dtype"}, + {"Out"}); + } } - const auto& shape = paddle::any_cast>(ctx.Attr("shape")); if (ctx.HasInput("ShapeTensor") && shape.empty()) { return KernelSignature("gaussian_random", {},