diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 7fef165f3739699810514cf9eb8d57e0e7309a33..19e661587716b396ac1726b72ff483e5d1349d42 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -517,6 +517,12 @@ void RegisterOperatorWithMetaInfo( auto& base_op_meta = op_meta_infos.front(); auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta); + + if (OpInfoMap::Instance().Has(op_name)) { + LOG(WARNING) << "Operator (" << op_name << ")has been registered."; + return; + } + auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); @@ -867,7 +873,7 @@ void RegisterOperatorWithMetaInfoMap( // load op api void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); - + VLOG(1) << "load custom_op lib: " << dso_name; typedef OpMetaInfoMap& get_op_meta_info_map_t(); auto* get_op_meta_info_map = detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py index 0f7ba84ffc147b75a5dbc29988263e3ff31b2d4c..052fe8b156a53cec04fc2a216768acc1d4528312 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -130,6 +130,17 @@ class TestJITLoad(unittest.TestCase): str(e)) self.assertTrue(caught_exception) + def test_load_multiple_module(self): + custom_module = load( + name='custom_conj_jit', + sources=['custom_conj_op.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cc flags + extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags + verbose=True) + custom_conj = custom_module.custom_conj + self.assertIsNotNone(custom_conj) + if __name__ == '__main__': unittest.main()