diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc index 9a0a5f07a7080593d8f13e07788c703edb92c7ad..405cefa99ebbbe147fc96f63567e13607732780e 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -145,3 +146,11 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { REGISTER_PASS(transpose_flatten_concat_fuse_pass, paddle::framework::ir::TransposeFlattenConcatFusePass); +REGISTER_PASS_CAPABILITY(transpose_flatten_concat_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("transpose", 0) + .EQ("transpose2", 0) + .EQ("flatten", 0) + .EQ("concat", 0) + .EQ("fusion_transpose_flatten_concat", 0)); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d733cf26ed209bcb86eaf2d366e45cfa0e7f9a90..92d9473141009216e3c7e64ccb793884dc67aadc 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -38,6 +38,7 @@ set(PYBIND_SRCS imperative.cc ir.cc inference_api.cc + compatible.cc generator_py.cc) if(WITH_GLOO) diff --git a/paddle/fluid/pybind/compatible.cc b/paddle/fluid/pybind/compatible.cc new file mode 100644 index 0000000000000000000000000000000000000000..971d230458db4bc2196ca529e01b0586da79567c --- /dev/null +++ b/paddle/fluid/pybind/compatible.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pybind/compatible.h" + +#include +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace py = pybind11; + +using paddle::framework::compatible::PassVersionCheckerRegistrar; + +namespace paddle { +namespace pybind { + +void BindCompatible(py::module* m) { + py::class_(*m, "PassVersionChecker") + .def_static("IsCompatible", [](const std::string& name) -> bool { + auto instance = PassVersionCheckerRegistrar::GetInstance(); + return instance.IsPassCompatible(name); + }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/compatible.h b/paddle/fluid/pybind/compatible.h new file mode 100644 index 0000000000000000000000000000000000000000..f9d4cf5888fee8f62ce2e64636da6b98542b1a75 --- /dev/null +++ b/paddle/fluid/pybind/compatible.h @@ -0,0 +1,23 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace paddle { +namespace pybind { +void BindCompatible(pybind11::module *m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 4b8f7c853ceaf2148722a9c65f38e0ec3d9f4df5..330254ecaafd29c00e8942765956ea065d2bb7cf 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -60,6 +60,7 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/pybind/box_helper_py.h" +#include "paddle/fluid/pybind/compatible.h" #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" @@ -2619,6 +2620,7 @@ All parameter, weight, gradient are variables in Paddle. BindGraph(&m); BindNode(&m); BindInferenceApi(&m); + BindCompatible(&m); BindDataset(&m); BindGenerator(&m); #ifdef PADDLE_WITH_CRYPTO diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py index dfcd1758db2b22b211f84be528739aa71132ab8a..34a52e7aed342ac8db471ad94b277efd0faf9d27 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_transpose_flatten_concat_fuse_pass.py @@ -17,6 +17,7 @@ import numpy as np from inference_pass_test import InferencePassTest import paddle.fluid as fluid import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker class TransposeFlattenConcatFusePassTest(InferencePassTest): @@ -45,6 +46,37 @@ class TransposeFlattenConcatFusePassTest(InferencePassTest): use_gpu = True self.check_output_with_option(use_gpu) + PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass') + + +class TransposeFlattenConcatFusePassWithAxisTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data1 = fluid.data(name="data1", shape=[5, 5, 5], dtype="float32") + data2 = fluid.data(name="data2", shape=[5, 5, 5], dtype="float32") + trans1 = fluid.layers.transpose(data1, perm=[2, 1, 0]) + trans2 = fluid.layers.transpose(data2, perm=[2, 1, 0]) + flatt1 = fluid.layers.flatten(trans1, axis=2) + flatt2 = fluid.layers.flatten(trans2, axis=2) + concat_out = fluid.layers.concat([flatt1, flatt2], axis=1) + # There is no parameters for above structure. + # Hence, append a batch_norm to avoid failure caused by load_combined. + out = fluid.layers.batch_norm(concat_out, is_test=True) + + self.feeds = { + "data1": np.random.random([5, 5, 5]).astype("float32"), + "data2": np.random.random([5, 5, 5]).astype("float32") + } + self.fetch_list = [out] + + def test_check_output(self): + # There is no cpu pass for transpose_flatten_concat_fuse + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + + PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass') + if __name__ == "__main__": unittest.main()