未验证 提交 a28e8128 编写于 作者: Y yzchenmonkey 提交者: GitHub

Add MegEngine converter for MACE (#658)

上级 43b96415
......@@ -127,9 +127,9 @@ class StridedSliceOp : public Operation {
strides_data, strides_data + strides->size());
MACE_CHECK(input->size() > 0 && input->dim_size() > 0 &&
input->dim_size() <= 4,
input->dim_size() <= 5, // for megengine is 5, the others are 4
"The input size should larger than 0."
" And input dims should be an integer in (0, 4].");
" And input dims should be an integer in (0, 5].");
std::vector<index_t> output_shape = {};
......
......@@ -60,6 +60,7 @@ PlatformTypeStrs = [
"tensorflow",
"caffe",
"onnx",
"megengine",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
type=str)
......
......@@ -220,8 +220,8 @@ class DeviceWrapper:
"MACE_LOG_TENSOR_RANGE=%d" % (1 if quantize_stat else 0),
"%s/%s" % (target_dir, target_name),
"--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_node='%s'" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
......@@ -322,8 +322,8 @@ class DeviceWrapper:
cmd.extend([
"%s/%s" % (self.data_dir, target_name),
"--model_name=%s" % model_tag,
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_node='%s'" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
......
......@@ -184,6 +184,10 @@ def convert_model(conf, quantize_stat):
from transform import onnx_converter
converter = onnx_converter.OnnxConverter(option,
conf["model_file_path"])
elif platform == Platform.MEGENGINE:
from transform import megengine_converter
converter = megengine_converter.MegengineConverter(
option, conf["model_file_path"])
else:
mace_check(False, "Mace do not support platorm %s yet." % platform)
......
......@@ -145,7 +145,7 @@ def run_model_for_device(flags, args, dev, model_name, model_conf):
"device": runtime.name
}
opts = ["--%s=%s" % (arg_key, arg_val) for arg_key, arg_val in
opts = ["--%s='%s'" % (arg_key, arg_val) for arg_key, arg_val in
model_args.items()] + args
should_generate_data = (flags.validate
or flags.tune or "--benchmark" in opts)
......
......@@ -86,6 +86,7 @@ class FrameworkType(Enum):
TENSORFLOW = 0
CAFFE = 1
ONNX = 2
MEGENGINE = 3
MaceSupportedOps = [
......@@ -547,7 +548,6 @@ class ConverterOption(object):
# Model structure related transformation
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_FAKE_QUANTIZE,
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
......
此差异已折叠。
......@@ -149,6 +149,7 @@ class Platform(Enum):
TENSORFLOW = 0
CAFFE = 1
ONNX = 2
MEGENGINE = 3
def parse_platform(str):
......
......@@ -318,6 +318,51 @@ def validate_onnx_model(model_file,
mace_out_value, value,
validation_threshold, log_file)
def validate_megengine_model(model_file, input_file,
mace_out_file, input_names, input_shapes,
input_data_formats, output_names, output_shapes,
output_data_formats, validation_threshold,
input_data_types, log_file):
import megengine._internal as mgb
if not os.path.isfile(model_file):
common.MaceLogger.error(
VALIDATION_MODULE,
"Input graph file '" + model_file + "' does not exist!",
)
feed_inputs = []
for i in range(len(input_names)):
input_value = load_data(
util.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == DataFormat.NHWC and \
len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value)
cg, _, outputs = mgb.load_comp_graph_from_file(model_file)
inputs = mgb.cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
inputs = sorted(inputs, key=lambda i: i.name)
outputs = list(map(mgb.copy_output, outputs))
if len(outputs) == 1:
(outputs,) = outputs
func = cg.compile(inputs, outputs)
mge_output_value = func(*feed_inputs)
for i in range(len(output_names)):
output_file_name = \
util.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == DataFormat.NHWC and \
len(output_shapes[i]) == 4):
mace_out_value = \
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file)
def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_shape, output_shape, input_data_format,
......@@ -354,3 +399,12 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_node, output_shape, output_data_format,
validation_threshold,
input_data_type, backend, log_file)
elif platform == Platform.MEGENGINE:
validate_megengine_model(model_file,
input_file, mace_out_file,
input_node, input_shape,
input_data_format,
output_node, output_shape,
output_data_format,
validation_threshold,
input_data_type, log_file)
......@@ -748,8 +748,8 @@ def validate_model(abi,
"--input_file=/mace/%s" % input_file_name,
"--mace_out_file=/mace/%s" % output_file_name,
"--device_type=%s" % device_type,
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_node='%s'" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
......@@ -761,6 +761,18 @@ def validate_model(abi,
validation_outputs_data),
"--log_file=%s" % log_file,
_fg=True)
elif platform == "megengine":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name),
device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_data_formats),
",".join(output_data_formats),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend,
validation_outputs_data,
log_file)
six.print_("Validation done!\n")
......
......@@ -331,6 +331,52 @@ def validate_onnx_model(platform, device_type, model_file,
validation_threshold, log_file)
def validate_megengine_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
input_data_formats, output_names, output_shapes,
output_data_formats, validation_threshold,
input_data_types, log_file):
import megengine._internal as mgb
if not os.path.isfile(model_file):
common.MaceLogger.error(
VALIDATION_MODULE,
"Input graph file '" + model_file + "' does not exist!",
)
feed_inputs = []
for i in range(len(input_names)):
input_value = load_data(
common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == common.DataFormat.NHWC and \
len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value)
cg, _, outputs = mgb.load_comp_graph_from_file(model_file)
inputs = mgb.cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
inputs = sorted(inputs, key=lambda i: i.name)
outputs = list(map(mgb.copy_output, outputs))
if len(outputs) == 1:
(outputs,) = outputs
func = cg.compile(inputs, outputs)
mge_output_value = func(*feed_inputs)
for i in range(len(output_names)):
output_file_name = \
common.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == common.DataFormat.NHWC and \
len(output_shapes[i]) == 4):
mace_out_value = \
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file)
def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_data_format_str,
output_data_format_str, input_node, output_node,
......@@ -385,6 +431,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names, output_shapes, output_data_formats,
validation_threshold,
input_data_types, backend, log_file)
elif platform == 'megengine':
validate_megengine_model(platform, device_type, model_file,
input_file, mace_out_file,
input_names, input_shapes,
input_data_formats,
output_names, output_shapes,
output_data_formats,
validation_threshold,
input_data_types, log_file)
def parse_args():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册