未验证 提交 a863cf73 编写于 作者: F feng_shuai 提交者: GitHub

unsqueeze2_eltwise_fuse_pass_init (#33808)

上级 34466911
...@@ -73,6 +73,46 @@ PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) { ...@@ -73,6 +73,46 @@ PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) {
} // namespace patterns } // namespace patterns
UnsqueezeEltwiseFusePass::UnsqueezeEltwiseFusePass() {
AddOpCompat(OpCompat("unsqueeze2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("AxesTensor")
.IsOptional()
.IsTensor()
.End()
.AddInput("AxesTensorList")
.IsOptional()
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axes")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("elementwise_mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// The attribute value is - 1 before fusion and 0 after fusion
.AddAttr("axis")
.IsIntIn({-1, 0})
.End();
}
void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
...@@ -100,7 +140,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -100,7 +140,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "The subgraph is empty."; LOG(WARNING) << "The subgraph is empty.";
return; return;
} }
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle UnsqueezeEltwise fuse"; VLOG(4) << "handle UnsqueezeEltwise fuse";
GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern);
...@@ -123,6 +166,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -123,6 +166,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const {
IR_NODE_LINK_TO(eltwise_op, eltwise_out); IR_NODE_LINK_TO(eltwise_op, eltwise_out);
GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out}); GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out});
found_subgraph_count++; found_subgraph_count++;
if (!IsCompat(*eltwise_op->Op())) {
LOG(WARNING) << "unsqueeze2_eltwise_fuse_pass op compat failed.";
return;
}
} }
}; };
......
...@@ -34,6 +34,7 @@ class Graph; ...@@ -34,6 +34,7 @@ class Graph;
// it maybe change in runtime. // it maybe change in runtime.
class UnsqueezeEltwiseFusePass : public FusePassBase { class UnsqueezeEltwiseFusePass : public FusePassBase {
public: public:
UnsqueezeEltwiseFusePass();
virtual ~UnsqueezeEltwiseFusePass() {} virtual ~UnsqueezeEltwiseFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册