提交 e2e0a7f3 编写于 作者: Y yejianwu

support input with int32

上级 92f18fc6
...@@ -82,6 +82,8 @@ in one deployment file. ...@@ -82,6 +82,8 @@ in one deployment file.
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU. - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type * - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP. - [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP.
* - input_data_types
- [optional] The input data type for specific op(eg. gather), which can be [int32, float32], default to float32.
* - limit_opencl_kernel_time * - limit_opencl_kernel_time
- [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0. - [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0.
* - obfuscate * - obfuscate
......
...@@ -194,7 +194,8 @@ struct StridedSliceFunctor { ...@@ -194,7 +194,8 @@ struct StridedSliceFunctor {
strides_data[2] > 0 ? k < real_end_indices[2] strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2]; : k > real_end_indices[2];
k += strides_data[2]) { k += strides_data[2]) {
*output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k]; *output_data++ =
input_data[(i * input->dim(1) + j) * input->dim(2) + k];
} }
} }
} }
......
...@@ -130,6 +130,16 @@ class RuntimeType(object): ...@@ -130,6 +130,16 @@ class RuntimeType(object):
cpu_gpu = 'cpu+gpu' cpu_gpu = 'cpu+gpu'
InputDataTypeStrs = [
"int32",
"float32",
]
InputDataType = Enum('InputDataType',
[(ele, ele) for ele in InputDataTypeStrs],
type=str)
CPUDataTypeStrs = [ CPUDataTypeStrs = [
"fp32", "fp32",
] ]
...@@ -183,6 +193,7 @@ class YAMLKeyword(object): ...@@ -183,6 +193,7 @@ class YAMLKeyword(object):
output_shapes = 'output_shapes' output_shapes = 'output_shapes'
runtime = 'runtime' runtime = 'runtime'
data_type = 'data_type' data_type = 'data_type'
input_data_types = 'input_data_types'
limit_opencl_kernel_time = 'limit_opencl_kernel_time' limit_opencl_kernel_time = 'limit_opencl_kernel_time'
nnlib_graph_mode = 'nnlib_graph_mode' nnlib_graph_mode = 'nnlib_graph_mode'
obfuscate = 'obfuscate' obfuscate = 'obfuscate'
...@@ -447,6 +458,18 @@ def format_model_config(flags): ...@@ -447,6 +458,18 @@ def format_model_config(flags):
if not isinstance(value, list): if not isinstance(value, list):
subgraph[key] = [value] subgraph[key] = [value]
input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
if input_data_types:
if not isinstance(input_data_types, list):
subgraph[YAMLKeyword.input_data_types] = [input_data_types]
for input_data_type in input_data_types:
mace_check(input_data_type in InputDataTypeStrs,
ModuleName.YAML_CONFIG,
"'input_data_types' must be in "
+ str(InputDataTypeStrs))
else:
subgraph[YAMLKeyword.input_data_types] = []
validation_threshold = subgraph.get( validation_threshold = subgraph.get(
YAMLKeyword.validation_threshold, {}) YAMLKeyword.validation_threshold, {})
if not isinstance(validation_threshold, dict): if not isinstance(validation_threshold, dict):
...@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config, ...@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
sh_commands.tuning_run( sh_commands.tuning_run(
abi=target_abi, abi=target_abi,
...@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = [] runtime_list = []
if target_abi == ABIType.host: if target_abi == ABIType.host:
...@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi,
output_shapes=subgraphs[0][YAMLKeyword.output_shapes], output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
model_output_dir=model_output_dir, model_output_dir=model_output_dir,
phone_data_dir=PHONE_DATA_DIR, phone_data_dir=PHONE_DATA_DIR,
input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa
caffe_env=flags.caffe_env, caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa
if flags.report and flags.round > 0: if flags.report and flags.round > 0:
...@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num): ...@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = [] runtime_list = []
if target_abi == ABIType.host: if target_abi == ABIType.host:
runtime_list.extend([RuntimeType.cpu]) runtime_list.extend([RuntimeType.cpu])
......
...@@ -27,30 +27,37 @@ import common ...@@ -27,30 +27,37 @@ import common
# --input_ranges -1,1 # --input_ranges -1,1
def generate_data(name, shape, input_file, tensor_range): def generate_data(name, shape, input_file, tensor_range, input_data_type):
np.random.seed() np.random.seed()
data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \ data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \
+ tensor_range[0] + tensor_range[0]
input_file_name = common.formatted_file_name(input_file, name) input_file_name = common.formatted_file_name(input_file, name)
print 'Generate input file: ', input_file_name print 'Generate input file: ', input_file_name
data.astype(np.float32).tofile(input_file_name) if input_data_type == 'float32':
np_data_type = np.float32
elif input_data_type == 'int32':
np_data_type = np.int32
data.astype(np_data_type).tofile(input_file_name)
def generate_input_data(input_file, input_node, input_shape, input_ranges): def generate_input_data(input_file, input_node, input_shape, input_ranges,
input_data_type):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shapes = [shape for shape in input_shape.split(':')] input_shapes = [shape for shape in input_shape.split(':')]
if input_ranges: if input_ranges:
input_ranges = [r for r in input_ranges.split(':')] input_ranges = [r for r in input_ranges.split(':')]
else: else:
input_ranges = None input_ranges = [[-1, 1]] * len(input_names)
assert len(input_names) == len(input_shapes) if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa
for i in range(len(input_names)): for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')] shape = [int(x) for x in input_shapes[i].split(',')]
if input_ranges: generate_data(input_names[i], shape, input_file, input_ranges[i],
input_range = [float(x) for x in input_ranges[i].split(',')] input_data_types[i])
else:
input_range = [-1, 1]
generate_data(input_names[i], shape, input_file, input_range)
print "Generate input file done." print "Generate input file done."
...@@ -66,6 +73,8 @@ def parse_args(): ...@@ -66,6 +73,8 @@ def parse_args():
"--input_shape", type=str, default="1,64,64,3", help="input shape.") "--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument( parser.add_argument(
"--input_ranges", type=str, default="-1,1", help="input range.") "--input_ranges", type=str, default="-1,1", help="input range.")
parser.add_argument(
"--input_data_type", type=str, default="", help="input range.")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -73,4 +82,4 @@ def parse_args(): ...@@ -73,4 +82,4 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape, generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape,
FLAGS.input_ranges) FLAGS.input_ranges, FLAGS.input_data_type)
...@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir, ...@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir,
input_shapes, input_shapes,
input_files, input_files,
input_ranges, input_ranges,
input_data_types,
input_file_name="model_input"): input_file_name="model_input"):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name( formatted_name = common.formatted_file_name(
...@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir, ...@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir,
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes) input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges) input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name), generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str, input_nodes_str,
input_shapes_str, input_shapes_str,
input_ranges_str) input_ranges_str,
input_data_types_str)
input_file_list = [] input_file_list = []
if isinstance(input_files, list): if isinstance(input_files, list):
...@@ -800,6 +803,7 @@ def validate_model(abi, ...@@ -800,6 +803,7 @@ def validate_model(abi,
output_shapes, output_shapes,
model_output_dir, model_output_dir,
phone_data_dir, phone_data_dir,
input_data_types,
caffe_env, caffe_env,
input_file_name="model_input", input_file_name="model_input",
output_file_name="model_out", output_file_name="model_out",
...@@ -821,7 +825,7 @@ def validate_model(abi, ...@@ -821,7 +825,7 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), device_type, "%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold) validation_threshold, ",".join(input_data_types))
elif platform == "caffe": elif platform == "caffe":
image_name = "mace-caffe:latest" image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator" container_name = "mace_caffe_validator"
......
...@@ -40,11 +40,13 @@ import common ...@@ -40,11 +40,13 @@ import common
VALIDATION_MODULE = 'VALIDATION' VALIDATION_MODULE = 'VALIDATION'
def load_data(file): def load_data(file, data_type='float32'):
if os.path.isfile(file): if os.path.isfile(file):
return np.fromfile(file=file, dtype=np.float32) if data_type == 'float32':
else: return np.fromfile(file=file, dtype=np.float32)
return np.empty([0]) elif data_type == 'int32':
return np.fromfile(file=file, dtype=np.int32)
return np.empty([0])
def compare_output(platform, device_type, output_name, mace_out_value, def compare_output(platform, device_type, output_name, mace_out_value,
...@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name): ...@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name):
def validate_tf_model(platform, device_type, model_file, input_file, def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold): output_names, validation_threshold, input_data_types):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
common.MaceLogger.error( common.MaceLogger.error(
...@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file, ...@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file,
input_dict = {} input_dict = {}
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data( input_value = load_data(
common.formatted_file_name(input_file, input_names[i])) common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name( input_node = graph.get_tensor_by_name(
normalize_tf_tensor_name(input_names[i])) normalize_tf_tensor_name(input_names[i]))
...@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file, ...@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node, device_type, input_shape, output_shape, input_node, output_node,
validation_threshold): validation_threshold, input_data_type):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
for shape in input_shape_strs] for shape in input_shape_strs]
if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
output_names = [name for name in output_node.split(',')] output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
if platform == 'tensorflow': if platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file, validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold) output_names, validation_threshold, input_data_types)
elif platform == 'caffe': elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')] output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')] output_shapes = [[int(x) for x in shape.split(',')]
...@@ -220,6 +228,11 @@ def parse_args(): ...@@ -220,6 +228,11 @@ def parse_args():
"--output_shape", type=str, default="1,64,64,2", help="output shape.") "--output_shape", type=str, default="1,64,64,2", help="output shape.")
parser.add_argument( parser.add_argument(
"--input_node", type=str, default="input_node", help="input node") "--input_node", type=str, default="input_node", help="input node")
parser.add_argument(
"--input_data_type",
type=str,
default="",
help="input data type")
parser.add_argument( parser.add_argument(
"--output_node", type=str, default="output_node", help="output node") "--output_node", type=str, default="output_node", help="output node")
parser.add_argument( parser.add_argument(
...@@ -241,4 +254,5 @@ if __name__ == '__main__': ...@@ -241,4 +254,5 @@ if __name__ == '__main__':
FLAGS.output_shape, FLAGS.output_shape,
FLAGS.input_node, FLAGS.input_node,
FLAGS.output_node, FLAGS.output_node,
FLAGS.validation_threshold) FLAGS.validation_threshold,
FLAGS.input_data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册