提交 1400539d 编写于 作者: W WangChengke

mod gpu key to cuda

上级 5237a9a0
......@@ -117,7 +117,7 @@ def _build_to_func(desc_s, desc_d, attr=None):
return func(desc_s, attr)
def _build(desc_s, desc_d, attr=None):
if desc_d['process'] == 'gpu':
if desc_d['process'] == 'cuda':
func = tvm.get_global_func("composite_with_json")
return func(desc_s, attr)
rst = _build_to_func(desc_s, desc_d, attr)
......
......@@ -41,19 +41,24 @@ def compilewithjson_to_func(json_str):
logging.error(traceback.format_exc())
return False
processor = 'aicore'
if 'process' in kernel_info:
processor = kernel_info['process']
if 'composite' in kernel_info and kernel_info['composite'] is True:
try:
mod = composite._build_to_func(json_str, kernel_info)
return mod
if processor == 'cuda':
_ = composite._build(json_str, kernel_info)
return True
else:
mod = composite._build_to_func(json_str, kernel_info)
return mod
except Exception:
logging.error(traceback.format_exc())
return False
op_name = kernel_info['name']
op_func = None
processor = 'aicore'
if 'process' in kernel_info:
processor = kernel_info['process']
# get custom ops implementation first.
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
impl_path = os.path.realpath(kernel_info['impl_path'])
......
......@@ -461,8 +461,8 @@ NodeRef composite_with_json_to_func(const std::string &json_str, Map<std::string
std::string get_process(const std::string &json_str) {
size_t pos = json_str.find("\"process\"");
if (pos != std::string::npos && json_str.find("gpu", pos) != std::string::npos) {
return "gpu";
if (pos != std::string::npos && json_str.find("cuda", pos) != std::string::npos) {
return "cuda";
}
return "aicore";
}
......@@ -494,7 +494,7 @@ Module composite_with_json_gpu(const std::string &json_str, Map<std::string, Nod
}
Module composite_with_json(const std::string &json_str, Map<std::string, NodeRef> attrs) {
if (get_process(json_str) == "gpu") {
if (get_process(json_str) == "cuda") {
return composite_with_json_gpu(json_str, attrs);
}
auto build_rst = composite_with_json_to_func(json_str, attrs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册