未验证 提交 f827665a 编写于 作者: W Wilber 提交者: GitHub

[Pass Compatible] Bind python compatible. (#27262)

上级 bd77a425
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h" #include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -145,3 +146,11 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -145,3 +146,11 @@ void TransposeFlattenConcatFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(transpose_flatten_concat_fuse_pass, REGISTER_PASS(transpose_flatten_concat_fuse_pass,
paddle::framework::ir::TransposeFlattenConcatFusePass); paddle::framework::ir::TransposeFlattenConcatFusePass);
REGISTER_PASS_CAPABILITY(transpose_flatten_concat_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose", 0)
.EQ("transpose2", 0)
.EQ("flatten", 0)
.EQ("concat", 0)
.EQ("fusion_transpose_flatten_concat", 0));
...@@ -38,6 +38,7 @@ set(PYBIND_SRCS ...@@ -38,6 +38,7 @@ set(PYBIND_SRCS
imperative.cc imperative.cc
ir.cc ir.cc
inference_api.cc inference_api.cc
compatible.cc
generator_py.cc) generator_py.cc)
if(WITH_GLOO) if(WITH_GLOO)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/compatible.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace py = pybind11;
using paddle::framework::compatible::PassVersionCheckerRegistrar;
namespace paddle {
namespace pybind {
void BindCompatible(py::module* m) {
py::class_<PassVersionCheckerRegistrar>(*m, "PassVersionChecker")
.def_static("IsCompatible", [](const std::string& name) -> bool {
auto instance = PassVersionCheckerRegistrar::GetInstance();
return instance.IsPassCompatible(name);
});
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
void BindCompatible(pybind11::module *m);
} // namespace pybind
} // namespace paddle
...@@ -60,6 +60,7 @@ limitations under the License. */ ...@@ -60,6 +60,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h" #include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/compatible.h"
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
...@@ -2619,6 +2620,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2619,6 +2620,7 @@ All parameter, weight, gradient are variables in Paddle.
BindGraph(&m); BindGraph(&m);
BindNode(&m); BindNode(&m);
BindInferenceApi(&m); BindInferenceApi(&m);
BindCompatible(&m);
BindDataset(&m); BindDataset(&m);
BindGenerator(&m); BindGenerator(&m);
#ifdef PADDLE_WITH_CRYPTO #ifdef PADDLE_WITH_CRYPTO
......
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from inference_pass_test import InferencePassTest from inference_pass_test import InferencePassTest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
class TransposeFlattenConcatFusePassTest(InferencePassTest): class TransposeFlattenConcatFusePassTest(InferencePassTest):
...@@ -45,6 +46,37 @@ class TransposeFlattenConcatFusePassTest(InferencePassTest): ...@@ -45,6 +46,37 @@ class TransposeFlattenConcatFusePassTest(InferencePassTest):
use_gpu = True use_gpu = True
self.check_output_with_option(use_gpu) self.check_output_with_option(use_gpu)
PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass')
class TransposeFlattenConcatFusePassWithAxisTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(name="data1", shape=[5, 5, 5], dtype="float32")
data2 = fluid.data(name="data2", shape=[5, 5, 5], dtype="float32")
trans1 = fluid.layers.transpose(data1, perm=[2, 1, 0])
trans2 = fluid.layers.transpose(data2, perm=[2, 1, 0])
flatt1 = fluid.layers.flatten(trans1, axis=2)
flatt2 = fluid.layers.flatten(trans2, axis=2)
concat_out = fluid.layers.concat([flatt1, flatt2], axis=1)
# There is no parameters for above structure.
# Hence, append a batch_norm to avoid failure caused by load_combined.
out = fluid.layers.batch_norm(concat_out, is_test=True)
self.feeds = {
"data1": np.random.random([5, 5, 5]).astype("float32"),
"data2": np.random.random([5, 5, 5]).astype("float32")
}
self.fetch_list = [out]
def test_check_output(self):
# There is no cpu pass for transpose_flatten_concat_fuse
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
PassVersionChecker.IsCompatible('transpose_flatten_concat_fuse_pass')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册