提交 529d916a 编写于 作者: 李滨

Merge branch 'pytorch_converter' into 'master'

pytorch converter

See merge request applied-machine-learning/sysml/mace!1300
......@@ -69,6 +69,8 @@ enum FrameworkType {
TENSORFLOW = 0,
CAFFE = 1,
ONNX = 2,
MEGENGINE = 3,
PYTORCH = 4
};
template <typename T>
......
......@@ -61,6 +61,7 @@ PlatformTypeStrs = [
"caffe",
"onnx",
"megengine",
"pytorch",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
type=str)
......@@ -520,6 +521,13 @@ def format_model_config(flags):
if not isinstance(value, list):
subgraph[key] = [value]
subgraph[key] = [str(v) for v in subgraph[key]]
# --inputs_shapes will be passed to ELF file `mace_run_static', if input_shapes
# contains spaces, such as: '1, 3, 224, 224', because mace_run.cc use gflags to
# parse command line arguments, --input_shapes 1, 3, 224, 224 will be passed as
# `--input_shapes 1,'. So we strip out spaces here.
if key in [YAMLKeyword.input_shapes,
YAMLKeyword.output_shapes]:
subgraph[key] = [e.replace(' ', '') for e in subgraph[key]]
input_size = len(subgraph[YAMLKeyword.input_tensors])
output_size = len(subgraph[YAMLKeyword.output_tensors])
......
......@@ -632,6 +632,9 @@ class DeviceWrapper:
'Run model {} on {}'.format(model_name, self.device_name)))
model_config = configs[YAMLKeyword.models][model_name]
if model_config[YAMLKeyword.platform] == 'pytorch':
mace_check(flags.layers == "-1", "Device",
'extracting intermediate layer output is not supported in pytorch JIT yet') # noqa
model_runtime = model_config[YAMLKeyword.runtime]
subgraphs = model_config[YAMLKeyword.subgraphs]
......
......@@ -190,6 +190,10 @@ def convert_model(conf, quantize_stat):
from transform import megengine_converter
converter = megengine_converter.MegengineConverter(
option, conf["model_file_path"])
elif platform == Platform.PYTORCH:
from transform import pytorch_converter
converter = pytorch_converter.PytorchConverter(
option, conf["model_file_path"])
else:
mace_check(False, "Mace do not support platorm %s yet." % platform)
......
......@@ -88,6 +88,7 @@ class FrameworkType(Enum):
CAFFE = 1
ONNX = 2
MEGENGINE = 3
PYTORCH = 4
MaceSupportedOps = [
......
此差异已折叠。
......@@ -345,6 +345,7 @@ class Transformer(base_converter.ConverterInterface):
input_info.dims.extend(input_node.shape)
input_info.data_type = input_node.data_type
# tools/python/convert.py sets option.check_nodes
output_nodes = self._option.check_nodes.values()
for output_node in output_nodes:
output_info = net.output_info.add()
......@@ -1314,12 +1315,18 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow
# work for TensorFlow/PyTorch/ONNX
framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i
is_torch = framework == FrameworkType.PYTORCH.value
is_tf = framework == FrameworkType.TENSORFLOW.value
is_onnx = framework == FrameworkType.ONNX.value
if op.type == MaceOp.Reshape.name and \
len(op.input) == 2 and \
op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \
filter_format == DataFormat.HWIO and \
(is_tf or is_torch or is_onnx) and \
op.input[0] in self._producer:
input_op = self._producer[op.input[0]]
input_shape = input_op.output_shape[0].dims
......@@ -1334,8 +1341,13 @@ class Transformer(base_converter.ConverterInterface):
is_fc = False
else:
weight = self._consts[matmul_op.input[1]]
if len(weight.dims) != 2 or \
weight.dims[0] != op.output_shape[0].dims[1]:
od = op.output_shape[0].dims
wd = weight.dims
if len(wd) != 2:
is_fc = False
# tf fc weight: IO; onnx/pytorch fc weight: OI
if (is_tf and wd[0] != od[1]) or \
((is_torch or is_onnx) and wd[1] != od[1]):
is_fc = False
if is_fc:
print('convert reshape and matmul to fc')
......@@ -1346,24 +1358,40 @@ class Transformer(base_converter.ConverterInterface):
matmul_op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = input_shape[1:] + \
[weight_data.shape[1]]
if is_tf:
weight.dims[:] = input_shape[1:] + \
[weight_data.shape[1]]
if is_torch or is_onnx:
in_data_format = ConverterUtil.data_format(
input_op)
# OI+NCHW[2:]=OIHW
if in_data_format == DataFormat.NCHW:
weight.dims.extend(input_shape[2:])
# OI+NHWC[1:3]=OIHW
else:
weight.dims.extend(input_shape[1:3])
return True
# transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
if op.type == MaceOp.MatMul.name and \
filter_format == DataFormat.HWIO and \
(is_tf or is_torch or is_onnx) and \
op.input[1] in self._consts:
producer = self._producer[op.input[0]]
weight = self._consts[op.input[1]]
if len(weight.dims) == 2 and self.is_after_fc(op) and \
len(producer.output_shape[0].dims) == 2 and \
weight.dims[0] == producer.output_shape[0].dims[1]:
((is_tf and weight.dims[0] == producer.output_shape[0].dims[1]) or # noqa
(is_torch and weight.dims[1] == producer.output_shape[0].dims[1]) or # noqa
(is_onnx and weight.dims[1] == producer.output_shape[0].dims[1])): # noqa
six.print_('convert matmul to fc')
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = [1, 1] + list(weight_data.shape)
# only 1 of the 2 branches can be executed
if is_tf:
weight.dims[:] = [1, 1] + list(weight_data.shape)
if is_torch or is_onnx:
weight.dims.extend([1, 1])
return True
if self._option.device == DeviceType.APU.value:
......@@ -2259,7 +2287,7 @@ class Transformer(base_converter.ConverterInterface):
dim_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str)
shape_tensor = None
if len(op.input) == 1:
print("Transform Caffe Reshape")
print("Transform Caffe or PyTorch Reshape")
dims = []
axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
# transform caffe reshape op
......
......@@ -151,6 +151,7 @@ class Platform(Enum):
CAFFE = 1
ONNX = 2
MEGENGINE = 3
PYTORCH = 4
def parse_platform(str):
......
......@@ -51,8 +51,8 @@ def execute(cmd, verbose=True):
print(line)
buf.append(line)
for l in p.stdout:
line = l.strip()
for li in p.stdout:
line = li.strip()
if verbose:
print(line)
buf.append(line)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
import os.path
import numpy as np
import six
......@@ -204,6 +205,48 @@ def validate_tf_model(model_file,
validation_threshold, log_file)
def validate_pytorch_model(model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
import torch
loaded_model = torch.jit.load(model_file)
pytorch_inputs = []
for i in range(len(input_names)):
input_value = load_data(
util.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if input_data_formats[i] == DataFormat.NHWC and \
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 3, 1, 2))
input_value = torch.from_numpy(input_value)
pytorch_inputs.append(input_value)
with torch.no_grad():
pytorch_outputs = loaded_model(*pytorch_inputs)
if isinstance(pytorch_outputs, torch.Tensor):
pytorch_outputs = [pytorch_outputs]
else:
if not isinstance(pytorch_outputs, (list, tuple)):
print('return type {} unsupported'.format(type(pytorch_outputs)))
sys.exit(1)
for i in range(len(output_names)):
value = pytorch_outputs[i].numpy()
output_file_name = util.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
# MACE: always returns tensor of dim 1
# pytorch: NCHW, conversion is needed
if output_data_formats[i] == DataFormat.NHWC and \
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value.reshape(output_shapes[i])\
.transpose((0, 3, 1, 2))
compare_output(output_names[i], mace_out_value,
value, validation_threshold, log_file)
def validate_caffe_model(model_file, input_file,
mace_out_file, weight_file,
input_names, input_shapes, input_data_formats,
......@@ -387,6 +430,12 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_node, output_shape, output_data_format,
validation_threshold, input_data_type,
log_file)
elif platform == Platform.PYTORCH:
validate_pytorch_model(model_file, input_file, mace_out_file,
input_node, input_shape, input_data_format,
output_node, output_shape, output_data_format,
validation_threshold, input_data_type,
log_file)
elif platform == Platform.CAFFE:
validate_caffe_model(model_file,
input_file, mace_out_file, weight_file,
......
......@@ -53,7 +53,8 @@ def strip_invalid_utf8(str):
def split_stdout(stdout_str):
stdout_str = strip_invalid_utf8(stdout_str)
# Filter out last empty line
return [l.strip() for l in stdout_str.split('\n') if len(l.strip()) > 0]
return [line.strip() for line in stdout_str.split('\n') if
len(line.strip()) > 0]
def make_output_processor(buff):
......@@ -659,7 +660,7 @@ def validate_model(abi,
sh.rm("-rf", "%s/%s" % (model_output_dir, formatted_name))
device.pull_from_data_dir(formatted_name, model_output_dir)
if platform == "tensorflow" or platform == "onnx":
if platform == "tensorflow" or platform == "onnx" or platform == "pytorch":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), device_type,
......
......@@ -216,6 +216,48 @@ def validate_tf_model(platform, device_type, model_file,
validation_threshold, log_file)
def validate_pytorch_model(platform, device_type, model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
import torch
loaded_model = torch.jit.load(model_file)
pytorch_inputs = []
for i in range(len(input_names)):
input_value = load_data(
common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if input_data_formats[i] == common.DataFormat.NHWC and \
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 3, 1, 2))
input_value = torch.from_numpy(input_value)
pytorch_inputs.append(input_value)
with torch.no_grad():
pytorch_outputs = loaded_model(*pytorch_inputs)
if isinstance(pytorch_outputs, torch.Tensor):
pytorch_outputs = [pytorch_outputs]
else:
if not isinstance(pytorch_outputs, (list, tuple)):
print('return type {} unsupported yet'.format(
type(pytorch_outputs)))
sys.exit(1)
for i in range(len(output_names)):
value = pytorch_outputs[i].numpy()
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
# MACE: NHWC, pytorch: NCHW, conversion is needed
if output_data_formats[i] == common.DataFormat.NHWC and \
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value.reshape(output_shapes[i])\
.transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], mace_out_value,
value, validation_threshold, log_file)
def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file,
input_names, input_shapes, input_data_formats,
......@@ -418,6 +460,13 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types,
log_file)
elif platform == 'pytorch':
validate_pytorch_model(platform, device_type,
model_file, input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes,
output_data_formats, validation_threshold,
input_data_types, log_file)
elif platform == 'caffe':
validate_caffe_model(platform, device_type, model_file,
input_file, mace_out_file, weight_file,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册