From 5f198a6e5f42b9e7e995c6419a3e943fddaeaa0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 23 Jun 2021 11:17:01 +0800 Subject: [PATCH] add op_compat for the seqpool_cvm_concat_fuse_pass, test=develop (#33559) --- .../ir/seqpool_cvm_concat_fuse_pass.cc | 46 ++++++++++++++++++ .../ir/seqpool_cvm_concat_fuse_pass.h | 2 +- paddle/fluid/operators/compat/cvm.pbtxt | 39 +++++++++++++++ .../operators/compat/sequence_pool.pbtxt | 47 +++++++++++++++++++ 4 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/compat/cvm.pbtxt create mode 100644 paddle/fluid/operators/compat/sequence_pool.pbtxt diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc index 6bff4a05627..effaa0814ea 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc @@ -52,6 +52,52 @@ static void GetConcatNodes(ir::Graph* graph, std::vector* concat_nodes) { } } // anonymous namespace +SeqPoolCVMConcatFusePass::SeqPoolCVMConcatFusePass() { + AddOpCompat(OpCompat("sequence_pool")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("MaxIndex") + .IsTensor() + .IsOptional() + .End() + .AddAttr("pooltype") + .IsStringIn({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"}) + .End() + .AddAttr("pad_value") + .End(); + AddOpCompat(OpCompat("cvm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("CVM") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("use_cvm") + .IsBoolEQ(true) + .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(1) + .End(); +} + void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init("seqpool_cvm_concat_fuse", graph); std::vector concat_nodes; diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h index b0a3573fb59..7680c30e485 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h @@ -44,7 +44,7 @@ class Graph; class SeqPoolCVMConcatFusePass : public FusePassBase { public: - virtual ~SeqPoolCVMConcatFusePass() {} + SeqPoolCVMConcatFusePass(); protected: void ApplyImpl(ir::Graph* graph) const override; diff --git a/paddle/fluid/operators/compat/cvm.pbtxt b/paddle/fluid/operators/compat/cvm.pbtxt new file mode 100644 index 00000000000..ccbeabc1f15 --- /dev/null +++ b/paddle/fluid/operators/compat/cvm.pbtxt @@ -0,0 +1,39 @@ +type: "cvm" +def { + inputs { + name: "X" + } + inputs { + name: "CVM" + } + outputs { + name: "Y" + } + attrs { + name: "use_cvm" + type: BOOLEAN + } +} +extra { + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} + diff --git a/paddle/fluid/operators/compat/sequence_pool.pbtxt b/paddle/fluid/operators/compat/sequence_pool.pbtxt new file mode 100644 index 00000000000..c45f457fe0d --- /dev/null +++ b/paddle/fluid/operators/compat/sequence_pool.pbtxt @@ -0,0 +1,47 @@ +type: "sequence_pool" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } + outputs { + name: "MaxIndex" + } + attrs { + name: "pooltype" + type: STRING + } + attrs { + name: "pad_value" + type: FLOAT + } +} +extra { + attrs { + name: "is_test" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} + -- GitLab