提交 a2957388 编写于 作者: W wenchunjiang

add param_type in json

上级 57c1da12
......@@ -15,6 +15,12 @@
"""tbe common"""
import json
import os
from attrdict import AttrDict
class ParamType(AttrDict):
Required = "required"
Dynamic = "dynamic"
Optional = "optional"
class TBEException(Exception):
......@@ -80,7 +86,62 @@ def _check_arg_info(item):
raise ValueError("Json string Errors, key:ori_format not found.")
if 'dtype' not in item or not item['dtype']:
raise ValueError("Json string Errors, key:dtype not found.")
if 'param_type' not in item or not item['param_type']:
raise ValueError("Json string Errors, key:param_type not found.")
def get_input_output(io_info, args):
"""
Parse args.
Args:
io_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for item in io_info:
arg = []
for info in item:
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
if info['valid']:
_check_arg_info(info)
del info['valid']
del info['name']
if len(item) > 1:
arg.append(info)
else:
if info['param_type'] == ParamType.Dynamic:
arg.append(info)
args.append(arg)
else:
args.append(info)
else:
if len(item) > 1:
arg.append(None)
else:
args.append(None)
if len(item) > 1:
args.append(arg)
def get_attr(attr_info, args):
"""
Parse args.
Args:
attr_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for item in attr_info:
if item["valid"]:
if 'value' not in item:
raise ValueError("Json string Errors, attr key:value not found.")
if item["name"] != "isRef":
args.append(item['value'])
def get_args(op_info, arg_type):
"""
......@@ -98,35 +159,12 @@ def get_args(op_info, arg_type):
args = []
if not op_info[arg_type]:
return args
if arg_type in ['inputs', 'outputs']:
for item in op_info[arg_type]:
arg = []
for info in item:
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
if info['valid']:
_check_arg_info(info)
del info['valid']
del info['name']
if len(item) > 1:
arg.append(info)
else:
args.append(info)
else:
if len(item) > 1:
arg.append(None)
else:
args.append(None)
if len(item) > 1:
args.append(arg)
arg_info = op_info[arg_type]
if arg_type in ['inputs', 'outputs']:
get_input_output(arg_info, args)
elif arg_type == 'attrs':
for item in op_info[arg_type]:
if item["valid"]:
if 'value' not in item:
raise ValueError("Json string Errors, attr key:value not found.")
if item["name"] != "isRef":
args.append(item['value'])
get_attr(arg_info, args)
return args
......
......@@ -147,6 +147,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
input_desc_json["format"] = format;
}
input_desc_json["valid"] = value;
input_desc_json["param_type"] = input_ptr->param_type();
input_list->emplace_back(input_desc_json);
}
return true;
......@@ -356,6 +357,7 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
output_obj["ori_format"] = kOpFormat_NCHW;
output_obj["name"] = output_ptr->name();
output_obj["valid"] = true;
output_obj["param_type"] = output_ptr->param_type();
output_list->emplace_back(output_obj);
(*output_idx)++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册