未验证 提交 f16ead9e 编写于 作者: J Jason 提交者: GitHub

Merge pull request #95 from Channingss/develop

support new model & fix bug
......@@ -137,14 +137,17 @@ def onnx2paddle(model_path, save_dir):
except:
print("onnx is not installed, use \"pip install onnx==1.5.0\".")
return
print("Now translating model from onnx to paddle.")
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
print("Now translating model from onnx to paddle.")
model = ONNXDecoder(model_path)
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
mapper = ONNXOpMapper(model)
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir)
......
此差异已折叠。
......@@ -23,6 +23,7 @@ from onnx.helper import get_attribute_value, make_attribute
from onnx.shape_inference import infer_shapes
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnx.numpy_helper import to_array
from onnx import AttributeProto, TensorProto, GraphProto
from collections import OrderedDict as Dict
import onnx
import numpy as np
......@@ -44,7 +45,7 @@ class ONNXGraphNode(GraphNode):
self.attr_map = self.get_attr_map()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
self.weight_inputs = list()
self.out_shapes = None
self.out_shapes = list()
self.dtype = None
def get_attr_map(self):
......@@ -58,11 +59,10 @@ class ONNXGraphNode(GraphNode):
@property
def value(self):
    """Return this node's ``value`` attribute, or ``None`` if it has none.

    The assertion accepts any node whose ``layer_type`` contains
    ``'Constant'`` (so ``ConstantOfShape`` passes as well).

    Returns:
        The entry stored under ``'value'`` in ``self.attr_map``, or ``None``
        when the node carries no such attribute.
    """
    assert 'Constant' in self.layer_type, "Only Constant | ConstantOfShape node has value."
    # The previous body also looked the attribute up on ``self.layer`` but
    # never used the result, and then returned ``self.attr_map[name]`` where
    # ``name`` was undefined (NameError). The parsed map is the only source
    # that is actually needed.
    if 'value' not in self.attr_map:
        return None
    return self.attr_map['value']
def get_attribute_value2(self, attr):
......@@ -110,23 +110,26 @@ class ONNXGraphDataNode(GraphNode):
def out_shapes(self):
    """Return the output shapes of this data node.

    Returns:
        A list holding a single shape, where the shape is the list of
        ``dim_value`` entries read from the layer's tensor type. The
        list-of-shapes form matches how op nodes collect one shape per
        output.
    """
    dims = self.layer.type.tensor_type.shape.dim
    # A stale flat assignment used to precede the append, which produced
    # ``[d0, d1, ..., [d0, d1, ...]]``; only the wrapped shape is wanted.
    return [[dim.dim_value for dim in dims]]
@property
def dtype(self):
    """Look up the NumPy dtype for the layer's ONNX tensor element type."""
    elem_type = self.layer.type.tensor_type.elem_type
    return TENSOR_TYPE_TO_NP_TYPE[elem_type]
class ONNXGraph(Graph):
def __init__(self, graph, onnx_model):
    """Set up the graph wrapper and collect static value information.

    Args:
        graph: the graph object handed to ``Graph.__init__`` and to
            ``inferred_model_value_info``.
        onnx_model: the full model; stored as ``self.onnx_model`` so that
            inference results can be computed from it later.
    """
    # The superseded single-argument constructor was removed: only this
    # two-argument form stores ``onnx_model`` and ``value_infos``.
    super(ONNXGraph, self).__init__(graph)
    self.onnx_model = onnx_model
    self.initializer = {}
    self.place_holder_nodes = list()
    self.get_place_holder_nodes()
    self.value_infos = self.inferred_model_value_info(graph)
    # Filled by get_results_of_inference; read by get_dynamic_shape.
    self.results_of_inference = dict()
def get_inner_nodes(self):
"""
generate inner node of ONNX model
......@@ -162,17 +165,22 @@ class ONNXGraph(Graph):
"""
build topo_sort of ONNX model
"""
data_node = self.place_holder_nodes[0]
value_info = self.value_infos[data_node]
input_shape = value_info['shape']
self.get_results_of_inference(self.onnx_model, input_shape)
for layer in self.model.node:
self.node_map[layer.name] = ONNXGraphNode(layer)
#set op node's dtype and out_shapes
for item in self.model.value_info:
if item.name in self.node_map:
self.node_map[item.name].dtype = TENSOR_TYPE_TO_NP_TYPE[
item.type.tensor_type.elem_type]
self.node_map[item.name].out_shapes = [
dim.dim_value for dim in item.type.tensor_type.shape.dim
]
node = ONNXGraphNode(layer)
self.node_map[layer.name] = node
for opt in layer.output:
if opt in self.value_infos:
value_info = self.value_infos[opt]
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else:
_, dtype, shape = self.get_dynamic_shape(opt)
node.dtype = dtype
node.out_shapes.append(shape)
for layer in self.model.input:
if layer.name not in self.node_map:
......@@ -199,7 +207,6 @@ class ONNXGraph(Graph):
format(in_node, layer_name))
else:
self.connect(in_node, layer_name)
#generate topo
super(ONNXGraph, self).build()
......@@ -227,31 +234,108 @@ class ONNXGraph(Graph):
weight = to_array(initializer)
yield name, weight
def inferred_model_value_info(self, graph):
    """
    collect value/type info for an ONNX graph

    Args:
        graph (onnx.GraphProto): graph whose ``value_info``, ``input`` and
            ``output`` entries are read.

    Returns:
        An ordered mapping from tensor name to a dict with the keys
        ``'dtype'`` (NumPy dtype), ``'shape'`` (list of ``dim_value``) and
        ``'external'`` (``True`` for graph inputs/outputs, ``False`` for
        intermediate ``value_info`` entries).
    """
    # The check is on GraphProto, so the message now says so (it used to
    # claim a ModelProto was expected).
    assert isinstance(graph,
                      onnx.GraphProto), 'graph is not a GraphProto instance'

    def describe(item, external):
        # Build the per-tensor record shared by all three sections.
        tensor_type = item.type.tensor_type
        return {
            'dtype': TENSOR_TYPE_TO_NP_TYPE[tensor_type.elem_type],
            'shape': [dim.dim_value for dim in tensor_type.shape.dim],
            'external': external
        }

    value_info = Dict()
    for item in graph.value_info:
        value_info[item.name] = describe(item, False)
    # Inputs first, then outputs; a name must not already be recorded.
    for section in (graph.input, graph.output):
        for item in section:
            assert item.name not in value_info
            value_info[item.name] = describe(item, True)
    return value_info
def get_results_of_inference(self, model, shape):
    """Run the model on random input and record one result per node.

    Every node of ``model.graph`` is registered as a graph output (named by
    ``node.name``), the model is executed through the backend returned by
    ``x2paddle.decoder.onnx_backend.prepare``, and each result is stored in
    ``self.results_of_inference`` under that name.

    Args:
        model: ONNX model; NOTE its ``graph.output`` field is overwritten
            here and is left holding the last batch of per-node outputs.
        shape: sequence of ints giving the shape of the random float32
            input (any rank).

    Returns:
        None. When torch is missing or is not version 1.1.0 a hint is
        printed and nothing is recorded.
    """
    try:
        import torch
        version = torch.__version__
        if '1.1.0' not in version:
            print("your model have dynamic graph, torch==1.1.0 is required")
            return
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit.
        print(
            "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"."
        )
        return
    from x2paddle.decoder.onnx_backend import prepare

    # Unpack the whole shape instead of indexing exactly four dimensions.
    np_images = np.random.rand(*shape).astype('float32')

    outputs = [
        helper.make_tensor_value_info(node.name, TensorProto.UNDEFINED, [])
        for node in model.graph.node
    ]

    # Outputs are evaluated in batches of 254 per backend run.
    # NOTE(review): the reason for 254 is not visible here - presumably a
    # backend limit on the number of outputs; confirm before changing.
    batch_size = 254
    for start in range(0, len(outputs), batch_size):
        batch = outputs[start:start + batch_size]
        model.graph.ClearField('output')
        model.graph.output.MergeFrom(batch)
        prepared_backend = prepare(model,
                                   device='CPU',
                                   no_check_UNSAFE=True)
        res = prepared_backend.run(inputs=np_images)
        for idx, info in enumerate(batch):
            self.results_of_inference[info.name] = res[idx]
    return
def get_dynamic_shape(self, layer):
    """
    look up the recorded inference result for `layer`

    Returns a tuple ``(values, dtype, shape)`` where ``values`` is the
    result converted with ``tolist()``. Raises ``KeyError`` when nothing was
    recorded under that name.
    """
    tensor = self.results_of_inference[layer]
    values = tensor.tolist()
    return values, tensor.dtype, tensor.shape
class ONNXDecoder(object):
def __init__(self, onnx_model):
    """Load an ONNX file, preprocess it and build the ONNXGraph.

    Args:
        onnx_model: path (or other argument accepted by ``onnx.load``) of
            the model to decode.

    The processed model is kept as ``self.model`` and the built graph as
    ``self.onnx_graph``.
    """
    model = onnx.load(onnx_model)
    print('model ir_version: {}, op version: {}'.format(
        model.ir_version, model.opset_import[0].version))

    if model.opset_import[0].version < 9:
        _logger.warning(
            'Now, onnx2paddle main support convert onnx model opset_verison == 9,'
            'opset_verison of your onnx model is %d < 9,'
            'some operator may cannot convert.',
            model.opset_import[0].version)

    # Validate both before and after polishing.
    check_model(model)
    model = polish_model(model)
    check_model(model)

    model = onnx.shape_inference.infer_shapes(model)
    model = self.optimize_model_skip_op_for_inference(model)
    model = self.optimize_model_strip_initializer(model)
    self.standardize_variable_name(model.graph)

    self.model = model
    graph_def = model.graph
    # The stale one-argument ``ONNXGraph(graph_def)`` construction was
    # dropped; its result was overwritten immediately and it does not match
    # the (graph, onnx_model) constructor.
    self.onnx_graph = ONNXGraph(graph_def, model)
    self.onnx_graph.build()
def build_value_refs(self, nodes):
......@@ -334,9 +418,13 @@ class ONNXDecoder(object):
output_name, output_refs)
else:
processed = -1
if processed > 0:
nodes_to_remove.append(node_idx)
for value_info in ret.graph.value_info:
for output in node.output:
if value_info.name == output:
ret.graph.value_info.remove(value_info)
print('skip op {}: {} -> {} -> {}'.format(
node_idx, input_name, node.op_type, output_name))
elif processed == 0:
......@@ -396,7 +484,6 @@ class ONNXDecoder(object):
"""
standardize variable name for paddle's code
"""
for initializer in graph.initializer:
initializer.name = self.make_variable_name(initializer.name)
for ipt in graph.input:
......@@ -455,43 +542,3 @@ class ONNXDecoder(object):
raise RuntimeError("Input mismatch {} != {}".format(
len(onnx_model.input), len(model.input)))
return onnx_model
def get_dynamic_shape_from_caffe2(self, layer, input_shapes):
    """
    get dynamic shape from caffe2.backend

    Args:
        layer: passed to ``self.split_model`` together with ``self.model``
            to obtain the model that is executed.
        input_shapes: sequence of shapes; only ``input_shapes[0]`` is used,
            as the shape of the random float32 input (any rank).

    Returns:
        The first backend output converted with ``tolist()``, or ``None``
        when torch is missing or is not version 1.1.0.
    """
    try:
        import torch
        version = torch.__version__
        if '1.1.0' not in version:
            print("your model have dynamic graph, torch==1.1.0 is required")
            return
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit.
        print(
            "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"."
        )
        return
    from caffe2.python.onnx.backend import prepare

    shape = input_shapes[0]
    # Unpack the whole shape instead of indexing exactly four dimensions.
    np_images = np.random.rand(*shape).astype('float32')
    num_onnx = self.split_model(self.model, layer)
    prepared_backend = prepare(num_onnx, device='CPU')
    output = prepared_backend.run(inputs=np_images)
    return output[0].tolist()
def get_dynamic_shape_from_onnx(self, layer, input_shapes):
    """
    get dynamic shape from onnxruntime

    Args:
        layer: passed to ``self.split_model`` together with ``self.model``
            to obtain the model that is executed.
        input_shapes: sequence of shapes; only ``input_shapes[0]`` is used,
            as the shape of the random float32 input (any rank).

    Returns:
        The first output of the run converted with ``tolist()``.
    """
    from onnxruntime.backend import prepare
    import numpy as np

    num_onnx = self.split_model(self.model, layer)
    sess = prepare(num_onnx)
    shape = input_shapes[0]
    # The leftover debug ``print(shape)`` and the unused ``rt`` alias were
    # removed; the shape is unpacked so inputs of any rank are supported.
    np_images = np.random.rand(*shape).astype('float32')
    output = sess.run(model=sess, inputs=np_images)
    return output[0].tolist()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .register import register
def InstanceNormalization_shape(input_shape):
    """Shape function for InstanceNormalization: the shape is passed through."""
    output_shape = input_shape
    return output_shape
def InstanceNormalization_layer(inputs, name=None):
    """Build instance normalization out of fluid primitives.

    Mean and variance are reduced over dims [2, 3] with ``keep_dim=True``
    (presumably the spatial axes of an NCHW tensor - TODO confirm), then a
    learnable per-channel scale and offset of shape ``inputs.shape[1:2]``
    are applied along axis 1.

    Args:
        inputs: input variable.
        name: optional prefix; the parameters are named ``<name>_scale`` and
            ``<name>_offset``. With ``None`` no explicit parameter names are
            passed to ``fluid.ParamAttr``.

    Returns:
        The normalized, scaled and shifted variable.
    """
    # TODO(lvmengsi@baidu.com): Check the accuracy when using fluid.layers.layer_norm.
    epsilon = 1e-5
    mean = fluid.layers.reduce_mean(inputs, dim=[2, 3], keep_dim=True)
    var = fluid.layers.reduce_mean(fluid.layers.square(inputs - mean),
                                   dim=[2, 3],
                                   keep_dim=True)
    # Both names used to be assigned only when ``name`` was given, so the
    # default ``name=None`` raised UnboundLocalError on the next statement.
    scale_name = None
    offset_name = None
    if name is not None:
        scale_name = name + "_scale"
        offset_name = name + "_offset"

    scale_param = fluid.ParamAttr(name=scale_name,
                                  initializer=fluid.initializer.Constant(1.0),
                                  trainable=True)
    offset_param = fluid.ParamAttr(name=offset_name,
                                   initializer=fluid.initializer.Constant(0.0),
                                   trainable=True)
    scale = fluid.layers.create_parameter(attr=scale_param,
                                          shape=inputs.shape[1:2],
                                          dtype="float32")
    offset = fluid.layers.create_parameter(attr=offset_param,
                                           shape=inputs.shape[1:2],
                                           dtype="float32")

    # (x - mean) * scale / sqrt(var + eps) + offset
    tmp = fluid.layers.elementwise_mul(x=(inputs - mean), y=scale, axis=1)
    tmp = tmp / fluid.layers.sqrt(var + epsilon)
    tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
    return tmp
def InstanceNormalization_weights(name, data=None):
    """Return the weight names for the layer: a list with ``<name>_scale``.

    ``data`` is accepted but not used.
    """
    return [name + '_scale']
# Register the three InstanceNormalization callbacks (shape, layer and
# weights) under the 'InstanceNormalization' kind when this module is imported.
register(kind='InstanceNormalization',
shape=InstanceNormalization_shape,
layer=InstanceNormalization_layer,
weights=InstanceNormalization_weights)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .register import get_registered_layers
#custom layer import begins
from . import InstanceNormalization
#custom layer import ends
custom_layers = get_registered_layers()
def set_args(f, params):
    """ collect keyword arguments for 'f' from the attributes of 'params'
    Args:
        f (function): a python function object
        params (object): an object whose attributes may match f's argument
            names; may be None
    Returns:
        arg_list (tuple): the positional argument names of f
        kwargs (dict): argument name -> value for every name found on params
    """
    code = f.__code__
    arg_list = code.co_varnames[:code.co_argcount]
    kwargs = {
        arg_name: getattr(params, arg_name)
        for arg_name in arg_list
        if params is not None and hasattr(params, arg_name)
    }
    return arg_list, kwargs
def has_layer(layer_type):
    """ tell whether 'layer_type' is a key of the registered custom layers
    """
    is_registered = layer_type in custom_layers
    return is_registered
def get_params(layer, layer_type):
    """ fetch the '<prefix>_param' attribute of 'layer' for 'layer_type'

    The prefix is 'convolution' for Deconvolution/ConvolutionDepthwise,
    'norm' for Normalize, a snake_case form of the type when it contains at
    least two ASCII capitals, and the lowercased type otherwise. Returns
    None when the layer has no such attribute.
    """
    import re

    lowered = layer_type.lower()
    if lowered in ("deconvolution", "convolutiondepthwise"):
        prefix = 'convolution'
    elif lowered == "normalize":
        prefix = 'norm'
    elif len(re.findall("[A-Z]", layer_type)) >= 2:
        # Insert '_' wherever an upper-case letter follows a lower-case one.
        pieces = []
        for idx, ch in enumerate(layer_type):
            if idx > 0 and ch.isupper() and layer_type[idx - 1].islower():
                pieces.append('_')
            pieces.append(ch.lower())
        prefix = ''.join(pieces)
    else:
        prefix = lowered
    return getattr(layer, prefix + '_param', None)
def compute_output_shape(node):
    """ run the registered shape function of a custom layer on node.input_shape
    """
    kind = node.layer_type
    assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
        kind)
    shape_func = custom_layers[kind]['shape']
    params = get_params(node.layer, kind)
    # Only the keyword arguments are needed; the name list is discarded.
    _, kwargs = set_args(shape_func, params)
    return shape_func(node.input_shape, **kwargs)
def make_custom_layer(node):
    """ return (source code, function) of the registered 'layer' callback
    """
    import inspect

    kind = node.layer_type
    assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
        kind)
    layer_func = custom_layers[kind]['layer']
    source = inspect.getsource(layer_func)
    return source, layer_func
def deal_weights(node, data=None):
    """ call the registered 'weights' callback with the node's layer name and data
    """
    weights_func = custom_layers[node.layer_type]['weights']
    return weights_func(node.layer_name, data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" this module provides 'register' for registering customized layers
"""
# Registry shared by register() and get_registered_layers():
# kind -> {'shape': ..., 'layer': ..., 'weights': ...}
g_custom_layers = {}


def register(kind, shape, layer, weights):
    """ add one custom layer, or several kinds sharing the same callbacks
    Args:
        @kind (str or list): type name(s) of the layer
        @shape (function): a function to generate the shape of layer's output
        @layer (function): a function to generate the paddle code of layer
        @weights (function): a function to deal with weights data
    Returns:
        None
    Raises:
        AssertionError: if shape/layer is not a plain function, kind is not
            a str or a list of str, or a kind is already registered
    """
    assert type(shape).__name__ == 'function', 'shape should be a function'
    assert type(layer).__name__ == 'function', 'layer should be a function'

    if type(kind) is not str:
        assert type(
            kind
        ) is list, 'invalid param "kind" for register, not a list or str'
        kinds = kind
    else:
        kinds = [kind]

    for k in kinds:
        assert type(
            k) is str, 'invalid param "kind" for register, not a list of str'
        assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
            k)
        print('register layer[%s]' % (k))
        entry = {'shape': shape, 'layer': layer, 'weights': weights}
        g_custom_layers[k] = entry


def get_registered_layers():
    """ return the registry dict itself (not a copy)
    """
    return g_custom_layers
......@@ -24,6 +24,7 @@ default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Gather': ['gather', ['X'], ['Out'],
dict(axis='')],
......@@ -46,8 +47,44 @@ default_op_mapping = {
dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)
],
'ReduceSum': [
'reduce_sum', ['X'], ['Out'],
dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)
],
#active function
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
dict(), dict(alpha=.01)],
'Elu': ['elu', ['X'], ['Out'],
dict(), dict(alpha=1.)],
'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'],
dict(alpha='threshold'),
dict(alpha=1.)
],
'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'],
dict(),
dict(axis=-1)], # TODO: pow for scalar exponent
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'],
dict(alpha='slope', beta='offset'),
dict(slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'],
dict(axis=''),
dict(axis=1)],
}
# Activation ops listed separately from the default mapping.
# NOTE(review): entries appear to follow the same layout as
# default_op_mapping ([fluid op, inputs, outputs, attr renames, defaults]) -
# confirm against the consumer of this table.
# The duplicated continuation line was removed; it made the literal a
# syntax error.
activefunc_op_mapping = {
    'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
                  dict(), dict(alpha=.01)],
}
default_ioa_constraint = {
......
......@@ -14,7 +14,6 @@
# TODO useless node remove
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
from x2paddle.core.util import *
class ONNXOptimizer(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册