From 09fae5cd3e9ccf95356bec4cdd70e5fef9b89650 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Tue, 11 Oct 2022 17:50:41 +0800 Subject: [PATCH] [CINN] support expand with static expand_times (#46776) --- .../framework/paddle2cinn/build_cinn_pass.cc | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 8becee8d485..9aca93dd735 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -81,6 +81,34 @@ OpTransInfo::OpTransInfo() { // judgment condition for the dynamic reshape2 dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape")); + + // judgment condition for the dynamic expand + dynamic_op_cond_.emplace("expand", [](const ir::Node& node) -> bool { + if (!node.IsOp()) { + return false; + } + auto* op_desc = node.Op(); + bool has_expand_times_tensor = + op_desc->Inputs().count("expand_times_tensor") && + op_desc->Inputs().at("expand_times_tensor").size(); + bool has_expand_times = op_desc->Inputs().count("ExpandTimes") && + op_desc->Inputs().at("ExpandTimes").size(); + return has_expand_times_tensor || has_expand_times; + }); + + // judgment condition for the dynamic expand_v2 + dynamic_op_cond_.emplace("expand_v2", [](const ir::Node& node) -> bool { + if (!node.IsOp()) { + return false; + } + auto* op_desc = node.Op(); + bool has_expand_shapes_tensor = + op_desc->Inputs().count("expand_shapes_tensor") && + op_desc->Inputs().at("expand_shapes_tensor").size(); + bool has_shape = op_desc->Inputs().count("Shape") && + op_desc->Inputs().at("Shape").size(); + return has_expand_shapes_tensor || has_shape; + }); } std::unordered_set OpTransInfo::GetDenyVarNames( -- GitLab