From fc6b4a506dda4237be46aa06a85bb17e15ffdd96 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Fri, 13 Aug 2021 19:11:52 +0800 Subject: [PATCH] Bug fix : Can't load multiple modules of custom c++ op (#34505) * Fix a bug : can't load more than one custom op module * Fix a bug : can't load more than one custom op module * add test for load multiple modules of custom c++ op * add config for Coverage CI --- paddle/fluid/framework/custom_operator.cc | 8 +++++++- .../fluid/tests/custom_op/test_custom_relu_op_jit.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 7fef165f373..19e66158771 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -517,6 +517,12 @@ void RegisterOperatorWithMetaInfo( auto& base_op_meta = op_meta_infos.front(); auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta); + + if (OpInfoMap::Instance().Has(op_name)) { + LOG(WARNING) << "Operator (" << op_name << ")has been registered."; + return; + } + auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); @@ -867,7 +873,7 @@ void RegisterOperatorWithMetaInfoMap( // load op api void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); - + VLOG(1) << "load custom_op lib: " << dso_name; typedef OpMetaInfoMap& get_op_meta_info_map_t(); auto* get_op_meta_info_map = detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py index 0f7ba84ffc1..052fe8b156a 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -130,6 +130,17 @@ class TestJITLoad(unittest.TestCase): str(e)) self.assertTrue(caught_exception) + def test_load_multiple_module(self): + custom_module = load( + name='custom_conj_jit', + sources=['custom_conj_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True) + custom_conj = custom_module.custom_conj + self.assertIsNotNone(custom_conj) + if __name__ == '__main__': unittest.main() -- GitLab