From a2957388300d176826cbb35fd18807ea483755f5 Mon Sep 17 00:00:00 2001 From: wenchunjiang Date: Mon, 25 May 2020 17:18:38 +0800 Subject: [PATCH] add param_type in json --- .../parallel_compile/tbe_compiler/common.py | 92 +++++++++++++------ .../ccsrc/kernel/tbe/tbe_kernel_build.cc | 2 + 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/common.py b/mindspore/_extends/parallel_compile/tbe_compiler/common.py index 39866d2ba..1aeba9889 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/common.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/common.py @@ -15,6 +15,12 @@ """tbe common""" import json import os +from attrdict import AttrDict + +class ParamType(AttrDict): + Required = "required" + Dynamic = "dynamic" + Optional = "optional" class TBEException(Exception): @@ -80,7 +86,62 @@ def _check_arg_info(item): raise ValueError("Json string Errors, key:ori_format not found.") if 'dtype' not in item or not item['dtype']: raise ValueError("Json string Errors, key:dtype not found.") + if 'param_type' not in item or not item['param_type']: + raise ValueError("Json string Errors, key:param_type not found.") + +def get_input_output(io_info, args): + """ + Parse args. + + Args: + io_info (dict): input or output info dict. + args (list): the arguments list. + + Raises: + Exception: If specific keyword is not found. + """ + for item in io_info: + arg = [] + for info in item: + if 'valid' not in info: + raise ValueError("Json string Errors, key:valid not found.") + if info['valid']: + _check_arg_info(info) + del info['valid'] + del info['name'] + if len(item) > 1: + arg.append(info) + else: + if info['param_type'] == ParamType.Dynamic: + arg.append(info) + args.append(arg) + else: + args.append(info) + else: + if len(item) > 1: + arg.append(None) + else: + args.append(None) + if len(item) > 1: + args.append(arg) +def get_attr(attr_info, args): + """ + Parse args. + + Args: + attr_info (dict): input or output info dict. + args (list): the arguments list. + + Raises: + Exception: If specific keyword is not found. + """ + for item in attr_info: + if item["valid"]: + if 'value' not in item: + raise ValueError("Json string Errors, attr key:value not found.") + if item["name"] != "isRef": + args.append(item['value']) def get_args(op_info, arg_type): """ @@ -98,35 +159,12 @@ def get_args(op_info, arg_type): args = [] if not op_info[arg_type]: return args - if arg_type in ['inputs', 'outputs']: - for item in op_info[arg_type]: - arg = [] - for info in item: - if 'valid' not in info: - raise ValueError("Json string Errors, key:valid not found.") - if info['valid']: - _check_arg_info(info) - del info['valid'] - del info['name'] - if len(item) > 1: - arg.append(info) - else: - args.append(info) - else: - if len(item) > 1: - arg.append(None) - else: - args.append(None) - if len(item) > 1: - args.append(arg) + arg_info = op_info[arg_type] + if arg_type in ['inputs', 'outputs']: + get_input_output(arg_info, args) elif arg_type == 'attrs': - for item in op_info[arg_type]: - if item["valid"]: - if 'value' not in item: - raise ValueError("Json string Errors, attr key:value not found.") - if item["name"] != "isRef": - args.append(item['value']) + get_attr(arg_info, args) return args diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 1322c81d6..bd5b0d632 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -147,6 +147,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr &anf_node, input_desc_json["format"] = format; } input_desc_json["valid"] = value; + input_desc_json["param_type"] = input_ptr->param_type(); input_list->emplace_back(input_desc_json); } return true; @@ -356,6 +357,7 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr &anf_node, co output_obj["ori_format"] = kOpFormat_NCHW; output_obj["name"] = output_ptr->name(); output_obj["valid"] = true; + output_obj["param_type"] = output_ptr->param_type(); output_list->emplace_back(output_obj); (*output_idx)++; -- GitLab