未验证 提交 5f198a6e 编写于 作者: 石晓伟 提交者: GitHub

add op_compat for the seqpool_cvm_concat_fuse_pass, test=develop (#33559)

上级 aa1aac9d
...@@ -52,6 +52,52 @@ static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) { ...@@ -52,6 +52,52 @@ static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
} }
} // anonymous namespace } // anonymous namespace
SeqPoolCVMConcatFusePass::SeqPoolCVMConcatFusePass() {
AddOpCompat(OpCompat("sequence_pool"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("MaxIndex")
.IsTensor()
.IsOptional()
.End()
.AddAttr("pooltype")
.IsStringIn({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"})
.End()
.AddAttr("pad_value")
.End();
AddOpCompat(OpCompat("cvm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("CVM")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("use_cvm")
.IsBoolEQ(true)
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(1)
.End();
}
void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const { void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("seqpool_cvm_concat_fuse", graph); FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
std::vector<Node*> concat_nodes; std::vector<Node*> concat_nodes;
......
...@@ -44,7 +44,7 @@ class Graph; ...@@ -44,7 +44,7 @@ class Graph;
class SeqPoolCVMConcatFusePass : public FusePassBase { class SeqPoolCVMConcatFusePass : public FusePassBase {
public: public:
virtual ~SeqPoolCVMConcatFusePass() {} SeqPoolCVMConcatFusePass();
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
......
type: "cvm"
def {
inputs {
name: "X"
}
inputs {
name: "CVM"
}
outputs {
name: "Y"
}
attrs {
name: "use_cvm"
type: BOOLEAN
}
}
extra {
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
type: "sequence_pool"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
outputs {
name: "MaxIndex"
}
attrs {
name: "pooltype"
type: STRING
}
attrs {
name: "pad_value"
type: FLOAT
}
}
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册