未验证 提交 1593c7ca 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Fix infershape if encounter TensorList and Attr("XXX") (#40420)

* [Phi] Fix infershape if encounter TensorList and Attr("XXX")

* add InferShapeArgumentMappingContext
上级 bd2d4fd0
......@@ -90,6 +90,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
bool IsForInferShape() const override { return true; }
bool IsRuntime() const override { return ctx_.IsRuntime(); }
private:
const InferShapeContext& ctx_;
};
......
......@@ -96,6 +96,13 @@ class ArgumentMappingContext {
// use this function to mark it comes from InferShapeArgumentMappingContext
// and will be used in infershape
virtual bool IsForInferShape() const = 0;
// NOTE(paddle-dev): [ Why do we export this interface? ]
// In old Fluid framework, some operators' Attribute can be a Tensor or
// TensorList. In this case, the InferShape logic will be different
// under CompileTime and RuntimeTime. So we export this interface to
// handle it conveniently. See "gaussian_random_sig.cc" for details.
virtual bool IsRuntime() const { return true; }
};
} // namespace phi
......@@ -18,14 +18,23 @@ namespace phi {
KernelSignature GaussianRandomOpArgumentMapping(
const ArgumentMappingContext& ctx) {
const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.InputSize("ShapeTensorList") > 0) {
// Infer output shape by Attr("shape") in CompileTime if it is specified.
if (!ctx.IsRuntime() && !shape.empty()) {
return KernelSignature("gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
} else {
return KernelSignature(
"gaussian_random",
{},
{"ShapeTensorList", "mean", "std", "seed", "dtype"},
{"Out"});
}
}
const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.HasInput("ShapeTensor") && shape.empty()) {
return KernelSignature("gaussian_random",
{},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册