diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 7daa9b5eff7d7ba25f38726efc88b23e072c491d..4101d593086cdbf8848034cd478e068c95d8f790 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -255,3 +256,15 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(seq_concat_fc_fuse_pass, paddle::framework::ir::SeqConcatFcFusePass); +REGISTER_PASS_CAPABILITY(seq_concat_fc_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("sequence_expand", 0) + .EQ("concat", 0) + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("sigmoid", 0) + .EQ("tanh", 0) + .EQ("relu", 0) + .EQ("identity", 0) + .EQ("fusion_seqexpand_concat_fc", 0)); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_seq_concat_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_seq_concat_fc_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..33f215dafda21c68af3edb6baaeca802edf82c5a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_seq_concat_fc_fuse_pass.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker + + +class SeqConcatFCFusePassTest(InferencePassTest): + def test_compatible(self): + self.assertTrue( + PassVersionChecker.IsCompatible('seq_concat_fc_fuse_pass')) + + +if __name__ == "__main__": + unittest.main()