未验证 提交 09fae5cd 编写于 作者: A Aganlengzi 提交者: GitHub

[CINN] support expand with static expand_times (#46776)

上级 75528ad6
...@@ -81,6 +81,34 @@ OpTransInfo::OpTransInfo() { ...@@ -81,6 +81,34 @@ OpTransInfo::OpTransInfo() {
// judgment condition for the dynamic reshape2 // judgment condition for the dynamic reshape2
dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape")); dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape"));
// judgment condition for the dynamic expand
dynamic_op_cond_.emplace("expand", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
bool has_expand_times_tensor =
op_desc->Inputs().count("expand_times_tensor") &&
op_desc->Inputs().at("expand_times_tensor").size();
bool has_expand_times = op_desc->Inputs().count("ExpandTimes") &&
op_desc->Inputs().at("ExpandTimes").size();
return has_expand_times_tensor || has_expand_times;
});
// judgment condition for the dynamic expand_v2
dynamic_op_cond_.emplace("expand_v2", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
bool has_expand_shapes_tensor =
op_desc->Inputs().count("expand_shapes_tensor") &&
op_desc->Inputs().at("expand_shapes_tensor").size();
bool has_shape = op_desc->Inputs().count("Shape") &&
op_desc->Inputs().at("Shape").size();
return has_expand_shapes_tensor || has_shape;
});
} }
std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册