diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py
index 241972162493ca70936b466904f3336c849ee7c9..a38bd8cc6bc109333c3c56c12e9ba4d4eca69d13 100644
--- a/mace/python/tools/converter.py
+++ b/mace/python/tools/converter.py
@@ -39,11 +39,6 @@ FLAGS = None
 device_type_map = {'cpu': mace_pb2.CPU,
                    'gpu': mace_pb2.GPU,
                    'dsp': mace_pb2.HEXAGON}
-device_data_type_map = {
-    mace_pb2.CPU: mace_pb2.DT_FLOAT,
-    mace_pb2.GPU: mace_pb2.DT_HALF,
-    mace_pb2.HEXAGON: mace_pb2.DT_UINT8
-}
 
 
 def file_checksum(fname):
@@ -128,6 +123,17 @@ def main(unused_args):
             FLAGS.weight_file)
         output_graph_def = converter.run()
+
+        if FLAGS.gpu_data_type == 'half':
+            gpu_data_type = mace_pb2.DT_HALF
+        else:
+            gpu_data_type = mace_pb2.DT_FLOAT
+        device_data_type_map = {
+            mace_pb2.CPU: mace_pb2.DT_FLOAT,
+            mace_pb2.GPU: gpu_data_type,
+            mace_pb2.HEXAGON: mace_pb2.DT_UINT8
+        }
+
         print("Transform model to one that can better run on device")
         if not FLAGS.runtime:
             cpu_graph_def = copy.deepcopy(output_graph_def)
@@ -177,7 +183,8 @@ def main(unused_args):
         source_converter_lib.convert_to_source(
             output_graph_def, model_checksum, weight_checksum, FLAGS.template,
             FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime,
-            FLAGS.embed_model_data, FLAGS.winograd)
+            FLAGS.embed_model_data, FLAGS.winograd,
+            FLAGS.gpu_data_type)
     else:
         with open(FLAGS.output, "wb") as f:
             f.write(output_graph_def.SerializeToString())
@@ -266,6 +273,8 @@ def parse_args():
         type=str2bool,
         default=True,
         help="embed model data.")
+    parser.add_argument(
+        "--gpu_data_type", type=str, default="half", help="half/float")
     return parser.parse_known_args()
 
 
diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py
index 5b43e61b07f89e95b849fee56aea7bc3f83381af..2adcd383f515b7aae88ca81ca376e1e865855e4d 100644
--- a/mace/python/tools/source_converter_lib.py
+++ b/mace/python/tools/source_converter_lib.py
@@ -109,11 +109,11 @@ def rename_tensor(net_def):
 
 
 class TensorInfo:
-    def __init__(self, id, t, runtime):
+    def __init__(self, id, t, runtime, gpu_data_type):
         self.id = id
         self.data_type = mace_pb2.DataType.Name(t.data_type)
         if t.data_type == mace_pb2.DT_FLOAT:
-            if runtime == 'gpu':
+            if runtime == 'gpu' and gpu_data_type == 'half':
                 self.data_type = mace_pb2.DT_HALF
                 self.data = bytearray(
                     np.array(t.float_data).astype(np.float16).tobytes())
@@ -137,7 +137,7 @@ def stringfy(value):
 
 def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
                       obfuscate, model_tag, output, runtime, embed_model_data,
-                      winograd_conv):
+                      winograd_conv, gpu_data_type):
     if obfuscate:
         obfuscate_name(net_def)
     else:
@@ -157,7 +157,7 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
         offset = 0
         counter = 0
         for t in net_def.tensors:
-            tensor_info = TensorInfo(counter, t, runtime)
+            tensor_info = TensorInfo(counter, t, runtime, gpu_data_type)
             # align
             if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
                 padding = 4 - offset % 4
@@ -208,7 +208,7 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
     build_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
     template_name = 'model.jinja2'
     tensors = [
-        TensorInfo(i, net_def.tensors[i], runtime)
+        TensorInfo(i, net_def.tensors[i], runtime, gpu_data_type)
         for i in range(len(net_def.tensors))
     ]
     checksum = model_checksum
diff --git a/tools/mace_tools.py b/tools/mace_tools.py
index 2af843a63f677e0e68c6ca3845ac872f5ece1862..af0597e132256fa9389a8e9f50c31ffb49f19ccd 100644
--- a/tools/mace_tools.py
+++ b/tools/mace_tools.py
@@ -523,6 +523,11 @@ def parse_args():
         type=str,
         default="cpu",
         help="validation runtime.")
+    parser.add_argument(
+        "--gpu_data_type",
+        type=str,
+        default="half",
+        help="[half | float].")
     return parser.parse_known_args()
 
 
@@ -778,7 +783,8 @@ def main(unused_args):
             model_config["dsp_mode"],
             embed_model_data,
             model_config["fast_conv"],
-            model_config["obfuscate"])
+            model_config["obfuscate"],
+            FLAGS.gpu_data_type)
 
     for target_abi in configs["target_abis"]:
         for target_soc in target_socs:
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index da93e0ec7f07912254da35d016bab2bdb3110eef..f8463cc92c77d4f07b0c264f2a71413ec4e584cd 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -471,7 +471,8 @@ def gen_model_code(model_codegen_dir,
                    dsp_mode,
                    embed_model_data,
                    fast_conv,
-                   obfuscate):
+                   obfuscate,
+                   gpu_data_type):
     print("* Genearte model code")
     bazel_build_common("//mace/python/tools:converter")
     if os.path.exists(model_codegen_dir):
@@ -497,6 +498,7 @@ def gen_model_code(model_codegen_dir,
         "--embed_model_data=%s" % embed_model_data,
         "--winograd=%s" % fast_conv,
         "--obfuscate=%s" % obfuscate,
+        "--gpu_data_type=%s" % gpu_data_type,
         _out=process_output,
         _bg=True,
         _err_to_out=True)
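Note for reviewers (not part of the patch): the practical effect of the new flag is on how DT_FLOAT weights are serialized for the GPU runtime. With --gpu_data_type=half (the default) they are re-encoded as float16, as the TensorInfo change above shows; with --gpu_data_type=float they keep full fp32 precision. A minimal numpy-only sketch of that difference; the fp32 branch is inferred, since it lies outside the diff context:

    import numpy as np

    float_data = [0.1, 0.2, 0.3]   # stand-in for a tensor's t.float_data
    gpu_data_type = 'half'         # value of the new --gpu_data_type flag

    if gpu_data_type == 'half':
        # Default: GPU weights stored as float16 (half the size, lower precision),
        # mirroring TensorInfo's np.float16 branch in the diff.
        data = bytearray(np.array(float_data).astype(np.float16).tobytes())
    else:
        # --gpu_data_type=float: keep float32 (branch inferred, not shown above).
        data = bytearray(np.array(float_data).astype(np.float32).tobytes())

    print(len(data))  # 6 bytes for 'half', 12 for 'float'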