未验证 提交 fc6b4a50 编写于 作者: Z zyfncg 提交者: GitHub

Bug fix : Can't load multiple modules of custom c++ op (#34505)

* Fix a bug : can't load more than one custom op module

* Fix a bug : can't load more than one custom op module

* add test for load multiple modules of custom c++ op

* add config for Coverage CI
上级 f421741c
...@@ -517,6 +517,12 @@ void RegisterOperatorWithMetaInfo( ...@@ -517,6 +517,12 @@ void RegisterOperatorWithMetaInfo(
auto& base_op_meta = op_meta_infos.front(); auto& base_op_meta = op_meta_infos.front();
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta); auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);
if (OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Operator (" << op_name << ")has been registered.";
return;
}
auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta);
auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta);
...@@ -867,7 +873,7 @@ void RegisterOperatorWithMetaInfoMap( ...@@ -867,7 +873,7 @@ void RegisterOperatorWithMetaInfoMap(
// load op api // load op api
void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
VLOG(1) << "load custom_op lib: " << dso_name;
typedef OpMetaInfoMap& get_op_meta_info_map_t(); typedef OpMetaInfoMap& get_op_meta_info_map_t();
auto* get_op_meta_info_map = auto* get_op_meta_info_map =
detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap"); detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
......
...@@ -130,6 +130,17 @@ class TestJITLoad(unittest.TestCase): ...@@ -130,6 +130,17 @@ class TestJITLoad(unittest.TestCase):
str(e)) str(e))
self.assertTrue(caught_exception) self.assertTrue(caught_exception)
def test_load_multiple_module(self):
custom_module = load(
name='custom_conj_jit',
sources=['custom_conj_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True)
custom_conj = custom_module.custom_conj
self.assertIsNotNone(custom_conj)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册