未验证 提交 1593c7ca 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Fix infershape if encounter TensorList and Attr("XXX") (#40420)

* [Phi] Fix infershape if encounter TensorList and Attr("XXX")

* add InferShapeArgumentMappingContext
上级 bd2d4fd0
...@@ -90,6 +90,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -90,6 +90,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
bool IsForInferShape() const override { return true; } bool IsForInferShape() const override { return true; }
bool IsRuntime() const override { return ctx_.IsRuntime(); }
private: private:
const InferShapeContext& ctx_; const InferShapeContext& ctx_;
}; };
......
...@@ -96,6 +96,13 @@ class ArgumentMappingContext { ...@@ -96,6 +96,13 @@ class ArgumentMappingContext {
// use this function to mark it comes from InferShapeArgumentMappingContext // use this function to mark it comes from InferShapeArgumentMappingContext
// and will be used in infershape // and will be used in infershape
virtual bool IsForInferShape() const = 0; virtual bool IsForInferShape() const = 0;
// NOTE(paddle-dev): [ Why do we export this interface? ]
// In old Fluid framework, some operators' Attribute can be a Tensor or
// TensorList. In this case, the InferShape logic will be different
// under CompileTime and RuntimeTime. So we export this interface to
// handle it conveniently. See "gaussian_random_sig.cc" for details.
virtual bool IsRuntime() const { return true; }
}; };
} // namespace phi } // namespace phi
...@@ -18,14 +18,23 @@ namespace phi { ...@@ -18,14 +18,23 @@ namespace phi {
KernelSignature GaussianRandomOpArgumentMapping( KernelSignature GaussianRandomOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.InputSize("ShapeTensorList") > 0) { if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("gaussian_random", // Infer output shape by Attr("shape") in CompileTime if it is specified.
{}, if (!ctx.IsRuntime() && !shape.empty()) {
{"ShapeTensorList", "mean", "std", "seed", "dtype"}, return KernelSignature("gaussian_random",
{"Out"}); {},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
} else {
return KernelSignature(
"gaussian_random",
{},
{"ShapeTensorList", "mean", "std", "seed", "dtype"},
{"Out"});
}
} }
const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.HasInput("ShapeTensor") && shape.empty()) { if (ctx.HasInput("ShapeTensor") && shape.empty()) {
return KernelSignature("gaussian_random", return KernelSignature("gaussian_random",
{}, {},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册