提交 d8d4a3f7 · 作者:李寅

Merge branch 'support_gpu_float' into 'master'

Support the float data type when running on the GPU

See merge request !513
......@@ -39,11 +39,6 @@ FLAGS = None
device_type_map = {'cpu': mace_pb2.CPU,
'gpu': mace_pb2.GPU,
'dsp': mace_pb2.HEXAGON}
device_data_type_map = {
mace_pb2.CPU: mace_pb2.DT_FLOAT,
mace_pb2.GPU: mace_pb2.DT_HALF,
mace_pb2.HEXAGON: mace_pb2.DT_UINT8
}
def file_checksum(fname):
......@@ -128,6 +123,17 @@ def main(unused_args):
FLAGS.weight_file)
output_graph_def = converter.run()
if FLAGS.gpu_data_type == 'half':
gpu_data_type = mace_pb2.DT_HALF
else:
gpu_data_type = mace_pb2.DT_FLOAT
device_data_type_map = {
mace_pb2.CPU: mace_pb2.DT_FLOAT,
mace_pb2.GPU: gpu_data_type,
mace_pb2.HEXAGON: mace_pb2.DT_UINT8
}
print("Transform model to one that can better run on device")
if not FLAGS.runtime:
cpu_graph_def = copy.deepcopy(output_graph_def)
......@@ -177,7 +183,8 @@ def main(unused_args):
source_converter_lib.convert_to_source(
output_graph_def, model_checksum, weight_checksum, FLAGS.template,
FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime,
FLAGS.embed_model_data, FLAGS.winograd)
FLAGS.embed_model_data, FLAGS.winograd,
FLAGS.gpu_data_type)
else:
with open(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
......@@ -266,6 +273,8 @@ def parse_args():
type=str2bool,
default=True,
help="embed model data.")
parser.add_argument(
"--gpu_data_type", type=str, default="half", help="half/float")
return parser.parse_known_args()
......
......@@ -109,11 +109,11 @@ def rename_tensor(net_def):
class TensorInfo:
def __init__(self, id, t, runtime):
def __init__(self, id, t, runtime, gpu_data_type):
self.id = id
self.data_type = mace_pb2.DataType.Name(t.data_type)
if t.data_type == mace_pb2.DT_FLOAT:
if runtime == 'gpu':
if runtime == 'gpu' and gpu_data_type == 'half':
self.data_type = mace_pb2.DT_HALF
self.data = bytearray(
np.array(t.float_data).astype(np.float16).tobytes())
......@@ -137,7 +137,7 @@ def stringfy(value):
def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
obfuscate, model_tag, output, runtime, embed_model_data,
winograd_conv):
winograd_conv, gpu_data_type):
if obfuscate:
obfuscate_name(net_def)
else:
......@@ -157,7 +157,7 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
offset = 0
counter = 0
for t in net_def.tensors:
tensor_info = TensorInfo(counter, t, runtime)
tensor_info = TensorInfo(counter, t, runtime, gpu_data_type)
# align
if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
padding = 4 - offset % 4
......@@ -208,7 +208,7 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
build_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
template_name = 'model.jinja2'
tensors = [
TensorInfo(i, net_def.tensors[i], runtime)
TensorInfo(i, net_def.tensors[i], runtime, gpu_data_type)
for i in range(len(net_def.tensors))
]
checksum = model_checksum
......
......@@ -523,6 +523,11 @@ def parse_args():
type=str,
default="cpu",
help="validation runtime.")
parser.add_argument(
"--gpu_data_type",
type=str,
default="half",
help="[half | float].")
return parser.parse_known_args()
......@@ -778,7 +783,8 @@ def main(unused_args):
model_config["dsp_mode"],
embed_model_data,
model_config["fast_conv"],
model_config["obfuscate"])
model_config["obfuscate"],
FLAGS.gpu_data_type)
for target_abi in configs["target_abis"]:
for target_soc in target_socs:
......
......@@ -471,7 +471,8 @@ def gen_model_code(model_codegen_dir,
dsp_mode,
embed_model_data,
fast_conv,
obfuscate):
obfuscate,
gpu_data_type):
print("* Genearte model code")
bazel_build_common("//mace/python/tools:converter")
if os.path.exists(model_codegen_dir):
......@@ -497,6 +498,7 @@ def gen_model_code(model_codegen_dir,
"--embed_model_data=%s" % embed_model_data,
"--winograd=%s" % fast_conv,
"--obfuscate=%s" % obfuscate,
"--gpu_data_type=%s" % gpu_data_type,
_out=process_output,
_bg=True,
_err_to_out=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册