提交 e2e0a7f3 编写于 作者: Y yejianwu

support input with int32

上级 92f18fc6
......@@ -82,6 +82,8 @@ in one deployment file.
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP.
* - input_data_types
- [optional] The input data type for specific ops (e.g. gather), which can be [int32, float32]; defaults to float32.
* - limit_opencl_kernel_time
- [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0.
* - obfuscate
......
......@@ -194,7 +194,8 @@ struct StridedSliceFunctor {
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
*output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k];
*output_data++ =
input_data[(i * input->dim(1) + j) * input->dim(2) + k];
}
}
}
......
......@@ -130,6 +130,16 @@ class RuntimeType(object):
cpu_gpu = 'cpu+gpu'
# Per-input tensor data types accepted by the YAML `input_data_types`
# key; some ops (e.g. gather) take int32 inputs instead of float32.
InputDataTypeStrs = [
    "int32",
    "float32",
]

# String-valued enum mirroring InputDataTypeStrs: with the `str` mixin,
# each member compares equal to its YAML spelling
# (InputDataType.int32 == "int32").
InputDataType = Enum('InputDataType',
                     [(ele, ele) for ele in InputDataTypeStrs],
                     type=str)

# Data types supported for the CPU runtime.
CPUDataTypeStrs = [
    "fp32",
]
......@@ -183,6 +193,7 @@ class YAMLKeyword(object):
output_shapes = 'output_shapes'
runtime = 'runtime'
data_type = 'data_type'
input_data_types = 'input_data_types'
limit_opencl_kernel_time = 'limit_opencl_kernel_time'
nnlib_graph_mode = 'nnlib_graph_mode'
obfuscate = 'obfuscate'
......@@ -447,6 +458,18 @@ def format_model_config(flags):
if not isinstance(value, list):
subgraph[key] = [value]
input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
if input_data_types:
if not isinstance(input_data_types, list):
subgraph[YAMLKeyword.input_data_types] = [input_data_types]
for input_data_type in input_data_types:
mace_check(input_data_type in InputDataTypeStrs,
ModuleName.YAML_CONFIG,
"'input_data_types' must be in "
+ str(InputDataTypeStrs))
else:
subgraph[YAMLKeyword.input_data_types] = []
validation_threshold = subgraph.get(
YAMLKeyword.validation_threshold, {})
if not isinstance(validation_threshold, dict):
......@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config,
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
sh_commands.tuning_run(
abi=target_abi,
......@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi,
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = []
if target_abi == ABIType.host:
......@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi,
output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
model_output_dir=model_output_dir,
phone_data_dir=PHONE_DATA_DIR,
input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa
caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa
if flags.report and flags.round > 0:
......@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges])
input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = []
if target_abi == ABIType.host:
runtime_list.extend([RuntimeType.cpu])
......
......@@ -27,30 +27,37 @@ import common
# --input_ranges -1,1
def generate_data(name, shape, input_file, tensor_range):
def generate_data(name, shape, input_file, tensor_range, input_data_type):
np.random.seed()
data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \
+ tensor_range[0]
input_file_name = common.formatted_file_name(input_file, name)
print 'Generate input file: ', input_file_name
data.astype(np.float32).tofile(input_file_name)
if input_data_type == 'float32':
np_data_type = np.float32
elif input_data_type == 'int32':
np_data_type = np.int32
data.astype(np_data_type).tofile(input_file_name)
def generate_input_data(input_file, input_node, input_shape, input_ranges):
def generate_input_data(input_file, input_node, input_shape, input_ranges,
input_data_type):
input_names = [name for name in input_node.split(',')]
input_shapes = [shape for shape in input_shape.split(':')]
if input_ranges:
input_ranges = [r for r in input_ranges.split(':')]
else:
input_ranges = None
assert len(input_names) == len(input_shapes)
input_ranges = [[-1, 1]] * len(input_names)
if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa
for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')]
if input_ranges:
input_range = [float(x) for x in input_ranges[i].split(',')]
else:
input_range = [-1, 1]
generate_data(input_names[i], shape, input_file, input_range)
generate_data(input_names[i], shape, input_file, input_ranges[i],
input_data_types[i])
print "Generate input file done."
......@@ -66,6 +73,8 @@ def parse_args():
"--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument(
"--input_ranges", type=str, default="-1,1", help="input range.")
parser.add_argument(
"--input_data_type", type=str, default="", help="input range.")
return parser.parse_known_args()
......@@ -73,4 +82,4 @@ def parse_args():
if __name__ == '__main__':
    # Script entry point: parse CLI flags and generate one random input
    # file per model input tensor.
    FLAGS, unparsed = parse_args()
    generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape,
                        FLAGS.input_ranges, FLAGS.input_data_type)
......@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir,
input_shapes,
input_files,
input_ranges,
input_data_types,
input_file_name="model_input"):
for input_name in input_nodes:
formatted_name = common.formatted_file_name(
......@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir,
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str,
input_shapes_str,
input_ranges_str)
input_ranges_str,
input_data_types_str)
input_file_list = []
if isinstance(input_files, list):
......@@ -800,6 +803,7 @@ def validate_model(abi,
output_shapes,
model_output_dir,
phone_data_dir,
input_data_types,
caffe_env,
input_file_name="model_input",
output_file_name="model_out",
......@@ -821,7 +825,7 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold)
validation_threshold, ",".join(input_data_types))
elif platform == "caffe":
image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator"
......
......@@ -40,10 +40,12 @@ import common
VALIDATION_MODULE = 'VALIDATION'
def load_data(file, data_type='float32'):
    """Load a flat tensor dump from *file* as a 1-D numpy array.

    Args:
        file: path of the binary file written by tofile().
        data_type: 'float32' (default) or 'int32', the on-disk
            element type.

    Returns:
        A 1-D numpy array of the requested dtype; an empty array when
        the file does not exist. NOTE(review): an unrecognized
        data_type also falls through to the empty array — callers are
        expected to pass only validated type names.
    """
    if os.path.isfile(file):
        if data_type == 'float32':
            return np.fromfile(file=file, dtype=np.float32)
        elif data_type == 'int32':
            return np.fromfile(file=file, dtype=np.int32)
    return np.empty([0])
......@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name):
def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
output_names, validation_threshold):
output_names, validation_threshold, input_data_types):
import tensorflow as tf
if not os.path.isfile(model_file):
common.MaceLogger.error(
......@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file,
input_dict = {}
for i in range(len(input_names)):
input_value = load_data(
common.formatted_file_name(input_file, input_names[i]))
common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name(
normalize_tf_tensor_name(input_names[i]))
......@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node,
validation_threshold):
validation_threshold, input_data_type):
input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')]
for shape in input_shape_strs]
if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes)
if platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
output_names, validation_threshold)
output_names, validation_threshold, input_data_types)
elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
......@@ -220,6 +228,11 @@ def parse_args():
"--output_shape", type=str, default="1,64,64,2", help="output shape.")
parser.add_argument(
"--input_node", type=str, default="input_node", help="input node")
parser.add_argument(
"--input_data_type",
type=str,
default="",
help="input data type")
parser.add_argument(
"--output_node", type=str, default="output_node", help="output node")
parser.add_argument(
......@@ -241,4 +254,5 @@ if __name__ == '__main__':
FLAGS.output_shape,
FLAGS.input_node,
FLAGS.output_node,
FLAGS.validation_threshold)
FLAGS.validation_threshold,
FLAGS.input_data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册