diff --git a/python/akg/composite/build_module.py b/python/akg/composite/build_module.py index 369ef761b3a34176b0424c1ca09a2e837fddaf4e..a54d99a71182d587b0ae64179c77a40c0e3831e4 100644 --- a/python/akg/composite/build_module.py +++ b/python/akg/composite/build_module.py @@ -117,7 +117,7 @@ def _build_to_func(desc_s, desc_d, attr=None): return func(desc_s, attr) def _build(desc_s, desc_d, attr=None): - if desc_d['process'] == 'gpu': + if desc_d['process'] == 'cuda': func = tvm.get_global_func("composite_with_json") return func(desc_s, attr) rst = _build_to_func(desc_s, desc_d, attr) diff --git a/python/akg/ms/message.py b/python/akg/ms/message.py index 89f72767b420e0b5850fce4f5e8fe9ee1b6799c3..8ed443fc4d0dec2e2ba845dbb75ca26bba454882 100644 --- a/python/akg/ms/message.py +++ b/python/akg/ms/message.py @@ -41,19 +41,24 @@ def compilewithjson_to_func(json_str): logging.error(traceback.format_exc()) return False + processor = 'aicore' + if 'process' in kernel_info: + processor = kernel_info['process'] + if 'composite' in kernel_info and kernel_info['composite'] is True: try: - mod = composite._build_to_func(json_str, kernel_info) - return mod + if processor == 'cuda': + _ = composite._build(json_str, kernel_info) + return True + else: + mod = composite._build_to_func(json_str, kernel_info) + return mod except Exception: logging.error(traceback.format_exc()) return False op_name = kernel_info['name'] op_func = None - processor = 'aicore' - if 'process' in kernel_info: - processor = kernel_info['process'] # get custom ops implementation first. if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: impl_path = os.path.realpath(kernel_info['impl_path']) diff --git a/src/composite/composite.cc b/src/composite/composite.cc index 7ad315bed3f801e092b471c8398ff8947f114b99..0a7660648e36b6c6057295c76cd6d1219b61cd26 100644 --- a/src/composite/composite.cc +++ b/src/composite/composite.cc @@ -461,8 +461,8 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map attrs) { - if (get_process(json_str) == "gpu") { + if (get_process(json_str) == "cuda") { return composite_with_json_gpu(json_str, attrs); } auto build_rst = composite_with_json_to_func(json_str, attrs);