From e2e0a7f370be0485a55dd95973b738eed0e7f4cb Mon Sep 17 00:00:00 2001
From: yejianwu
Date: Thu, 16 Aug 2018 21:11:49 +0800
Subject: [PATCH] support input with int32

---
 docs/user_guide/advanced_usage.rst |  2 ++
 mace/kernels/strided_slice.h       |  3 ++-
 tools/converter.py                 | 33 ++++++++++++++++++++++++++++++++++---
 tools/generate_data.py             | 31 ++++++++++++++++++++-----------
 tools/sh_commands.py               |  8 ++++++--
 tools/validate.py                  | 32 +++++++++++++++++++++++---------
 6 files changed, 83 insertions(+), 26 deletions(-)

diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst
index 8017f401..738f4448 100644
--- a/docs/user_guide/advanced_usage.rst
+++ b/docs/user_guide/advanced_usage.rst
@@ -82,6 +82,8 @@ in one deployment file.
       - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
     * - data_type
       - [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP.
+    * - input_data_types
+      - [optional] The input data type for specific ops (e.g. gather), which can be [int32, float32]. Defaults to float32.
     * - limit_opencl_kernel_time
       - [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0.
     * - obfuscate
diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h
index e966367f..20c508aa 100644
--- a/mace/kernels/strided_slice.h
+++ b/mace/kernels/strided_slice.h
@@ -194,7 +194,8 @@ struct StridedSliceFunctor {
                strides_data[2] > 0 ? k < real_end_indices[2]
                                    : k > real_end_indices[2];
                k += strides_data[2]) {
-          *output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k];
+          *output_data++ =
+              input_data[(i * input->dim(1) + j) * input->dim(2) + k];
         }
       }
     }
diff --git a/tools/converter.py b/tools/converter.py
index 941ea647..2e33f3be 100644
--- a/tools/converter.py
+++ b/tools/converter.py
@@ -130,6 +130,16 @@ class RuntimeType(object):
     cpu_gpu = 'cpu+gpu'


+InputDataTypeStrs = [
+    "int32",
+    "float32",
+]
+
+InputDataType = Enum('InputDataType',
+                     [(ele, ele) for ele in InputDataTypeStrs],
+                     type=str)
+
+
 CPUDataTypeStrs = [
     "fp32",
 ]
@@ -183,6 +193,7 @@ class YAMLKeyword(object):
     output_shapes = 'output_shapes'
     runtime = 'runtime'
     data_type = 'data_type'
+    input_data_types = 'input_data_types'
     limit_opencl_kernel_time = 'limit_opencl_kernel_time'
     nnlib_graph_mode = 'nnlib_graph_mode'
     obfuscate = 'obfuscate'
@@ -447,6 +458,18 @@ def format_model_config(flags):
                 if not isinstance(value, list):
                     subgraph[key] = [value]

+            input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
+            if input_data_types:
+                if not isinstance(input_data_types, list):
+                    subgraph[YAMLKeyword.input_data_types] = [input_data_types]
+                for input_data_type in subgraph[YAMLKeyword.input_data_types]:
+                    mace_check(input_data_type in InputDataTypeStrs,
+                               ModuleName.YAML_CONFIG,
+                               "'input_data_types' must be in " +
+                               str(InputDataTypeStrs))
+            else:
+                subgraph[YAMLKeyword.input_data_types] = []
+
             validation_threshold = subgraph.get(
                 YAMLKeyword.validation_threshold, {})
             if not isinstance(validation_threshold, dict):
@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config,
             subgraphs[0][YAMLKeyword.input_tensors],
             subgraphs[0][YAMLKeyword.input_shapes],
             subgraphs[0][YAMLKeyword.validation_inputs_data],
-            input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
+            input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
+            input_data_types=subgraphs[0][YAMLKeyword.input_data_types])

         sh_commands.tuning_run(
             abi=target_abi,
@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi,
             subgraphs[0][YAMLKeyword.input_tensors],
             subgraphs[0][YAMLKeyword.input_shapes],
             subgraphs[0][YAMLKeyword.validation_inputs_data],
-            input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
+            input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
+            input_data_types=subgraphs[0][YAMLKeyword.input_data_types])

         runtime_list = []
         if target_abi == ABIType.host:
@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi,
                         output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
                         model_output_dir=model_output_dir,
                         phone_data_dir=PHONE_DATA_DIR,
+                        input_data_types=subgraphs[0][YAMLKeyword.input_data_types],  # noqa
                         caffe_env=flags.caffe_env,
                         validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type])  # noqa
                 if flags.report and flags.round > 0:
@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
             subgraphs[0][YAMLKeyword.input_tensors],
             subgraphs[0][YAMLKeyword.input_shapes],
             subgraphs[0][YAMLKeyword.validation_inputs_data],
-            input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
+            input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
+            input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
         runtime_list = []
         if target_abi == ABIType.host:
             runtime_list.extend([RuntimeType.cpu])
diff --git a/tools/generate_data.py b/tools/generate_data.py
index d62297cc..1e485f20 100644
--- a/tools/generate_data.py
+++ b/tools/generate_data.py
@@ -27,30 +27,37 @@ import common
 # --input_ranges -1,1


-def generate_data(name, shape, input_file, tensor_range):
+def generate_data(name, shape, input_file, tensor_range, input_data_type):
     np.random.seed()
     data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \
         + tensor_range[0]
     input_file_name = common.formatted_file_name(input_file, name)
     print 'Generate input file: ', input_file_name
-    data.astype(np.float32).tofile(input_file_name)
+    if input_data_type == 'float32':
+        np_data_type = np.float32
+    elif input_data_type == 'int32':
+        np_data_type = np.int32
+    data.astype(np_data_type).tofile(input_file_name)


-def generate_input_data(input_file, input_node, input_shape, input_ranges):
+def generate_input_data(input_file, input_node, input_shape, input_ranges,
+                        input_data_type):
     input_names = [name for name in input_node.split(',')]
     input_shapes = [shape for shape in input_shape.split(':')]
     if input_ranges:
         input_ranges = [r for r in input_ranges.split(':')]
     else:
-        input_ranges = None
-    assert len(input_names) == len(input_shapes)
+        input_ranges = ["-1,1"] * len(input_names)
+    if input_data_type:
+        input_data_types = [data_type
+                            for data_type in input_data_type.split(',')]
+    else:
+        input_data_types = ['float32'] * len(input_names)
+    assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types)  # noqa
     for i in range(len(input_names)):
         shape = [int(x) for x in input_shapes[i].split(',')]
-        if input_ranges:
-            input_range = [float(x) for x in input_ranges[i].split(',')]
-        else:
-            input_range = [-1, 1]
-        generate_data(input_names[i], shape, input_file, input_range)
+        input_range = [float(x) for x in input_ranges[i].split(',')]
+        generate_data(input_names[i], shape, input_file, input_range, input_data_types[i])  # noqa

     print "Generate input file done."
@@ -66,6 +73,8 @@ def parse_args():
         "--input_shape", type=str, default="1,64,64,3", help="input shape.")
     parser.add_argument(
         "--input_ranges", type=str, default="-1,1", help="input range.")
+    parser.add_argument(
+        "--input_data_type", type=str, default="", help="input data type.")
     return parser.parse_known_args()


@@ -73,4 +82,4 @@
 if __name__ == '__main__':
     FLAGS, unparsed = parse_args()
     generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape,
-                        FLAGS.input_ranges)
+                        FLAGS.input_ranges, FLAGS.input_data_type)
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index d9b50342..ebef1c4a 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir,
                      input_shapes,
                      input_files,
                      input_ranges,
+                     input_data_types,
                      input_file_name="model_input"):
     for input_name in input_nodes:
         formatted_name = common.formatted_file_name(
@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir,
     input_nodes_str = ",".join(input_nodes)
     input_shapes_str = ":".join(input_shapes)
     input_ranges_str = ":".join(input_ranges)
+    input_data_types_str = ",".join(input_data_types)
     generate_input_data("%s/%s" % (model_output_dir, input_file_name),
                         input_nodes_str,
                         input_shapes_str,
-                        input_ranges_str)
+                        input_ranges_str,
+                        input_data_types_str)

     input_file_list = []
     if isinstance(input_files, list):
@@ -800,6 +803,7 @@ def validate_model(abi,
                    output_shapes,
                    model_output_dir,
                    phone_data_dir,
+                   input_data_types,
                    caffe_env,
                    input_file_name="model_input",
                    output_file_name="model_out",
@@ -821,7 +825,7 @@ def validate_model(abi,
                  "%s/%s" % (model_output_dir, output_file_name), device_type,
                  ":".join(input_shapes), ":".join(output_shapes),
                  ",".join(input_nodes), ",".join(output_nodes),
-                 validation_threshold)
+                 validation_threshold, ",".join(input_data_types))
     elif platform == "caffe":
         image_name = "mace-caffe:latest"
         container_name = "mace_caffe_validator"
diff --git a/tools/validate.py b/tools/validate.py
index 87bb3458..516cf512 100644
--- a/tools/validate.py
+++ b/tools/validate.py
@@ -40,11 +40,13 @@ import common
 VALIDATION_MODULE = 'VALIDATION'


-def load_data(file):
+def load_data(file, data_type='float32'):
     if os.path.isfile(file):
-        return np.fromfile(file=file, dtype=np.float32)
-    else:
-        return np.empty([0])
+        if data_type == 'float32':
+            return np.fromfile(file=file, dtype=np.float32)
+        elif data_type == 'int32':
+            return np.fromfile(file=file, dtype=np.int32)
+    return np.empty([0])


 def compare_output(platform, device_type, output_name, mace_out_value,
@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name):

 def validate_tf_model(platform, device_type, model_file, input_file,
                       mace_out_file, input_names, input_shapes,
-                      output_names, validation_threshold):
+                      output_names, validation_threshold, input_data_types):
     import tensorflow as tf
     if not os.path.isfile(model_file):
         common.MaceLogger.error(
@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file,
             input_dict = {}
             for i in range(len(input_names)):
                 input_value = load_data(
-                    common.formatted_file_name(input_file, input_names[i]))
+                    common.formatted_file_name(input_file, input_names[i]),
+                    input_data_types[i])
                 input_value = input_value.reshape(input_shapes[i])
                 input_node = graph.get_tensor_by_name(
                     normalize_tf_tensor_name(input_names[i]))
@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
 def validate(platform, model_file, weight_file, input_file,
              mace_out_file, device_type, input_shape, output_shape,
              input_node, output_node,
-             validation_threshold):
+             validation_threshold, input_data_type):
     input_names = [name for name in input_node.split(',')]
     input_shape_strs = [shape for shape in input_shape.split(':')]
     input_shapes = [[int(x) for x in shape.split(',')]
                     for shape in input_shape_strs]
+    if input_data_type:
+        input_data_types = [data_type
+                            for data_type in input_data_type.split(',')]
+    else:
+        input_data_types = ['float32'] * len(input_names)
     output_names = [name for name in output_node.split(',')]
     assert len(input_names) == len(input_shapes)

     if platform == 'tensorflow':
         validate_tf_model(platform, device_type, model_file, input_file,
                           mace_out_file, input_names, input_shapes,
-                          output_names, validation_threshold)
+                          output_names, validation_threshold, input_data_types)
     elif platform == 'caffe':
         output_shape_strs = [shape for shape in output_shape.split(':')]
         output_shapes = [[int(x) for x in shape.split(',')]
@@ -220,6 +228,11 @@ def parse_args():
         "--output_shape", type=str, default="1,64,64,2", help="output shape.")
     parser.add_argument(
         "--input_node", type=str, default="input_node", help="input node")
+    parser.add_argument(
+        "--input_data_type",
+        type=str,
+        default="",
+        help="input data type")
     parser.add_argument(
         "--output_node", type=str, default="output_node", help="output node")
     parser.add_argument(
@@ -241,4 +254,5 @@ if __name__ == '__main__':
              FLAGS.output_shape,
              FLAGS.input_node,
              FLAGS.output_node,
-             FLAGS.validation_threshold)
+             FLAGS.validation_threshold,
+             FLAGS.input_data_type)
-- 
GitLab
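
Note: a minimal deployment-file fragment exercising the new input_data_types key might look like the sketch below. This is illustrative only; the surrounding keys (input_tensors, input_shapes, output_tensors, output_shapes) follow the existing deployment-file format described in docs/user_guide/advanced_usage.rst, and the tensor names and shapes here are made up.

    subgraphs:
      - input_tensors:
          - input
        input_shapes:
          - 1,64,64,3
        # new optional key added by this patch; one entry per input tensor
        input_data_types:
          - int32
        output_tensors:
          - output
        output_shapes:
          - 1,64,64,2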