提交 c66a5423 编写于 作者: C channingss

support model Face_cyclegan

上级 339aedd2
......@@ -44,7 +44,7 @@ class ONNXGraphNode(GraphNode):
self.attr_map = self.get_attr_map()
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
self.weight_inputs = list()
self.out_shapes = None
self.out_shapes = list()
self.dtype = None
def get_attr_map(self):
......@@ -58,11 +58,11 @@ class ONNXGraphNode(GraphNode):
@property
def value(self):
assert 'Constant' in self.layer_type, "Only Constant node has value."
attr = self.layer.attr['value']
if 'value' in self.attr_map:
return default
assert 'Constant' in self.layer_type, "Only Constant | ConstantOfShape node has value."
print(self.layer)
attr = self.layer.attribute['value']
if 'value' not in self.attr_map:
return None
return self.attr_map[name]
def get_attribute_value2(self, attr):
......@@ -110,13 +110,12 @@ class ONNXGraphDataNode(GraphNode):
def out_shapes(self):
values = self.layer.type.tensor_type.shape.dim
out_shapes = list()
out_shapes = [dim.dim_value for dim in values]
out_shapes.append([dim.dim_value for dim in values])
return out_shapes
@property
def dtype(self):
dtype = self.layer.type.tensor_type.elem_type
return TENSOR_TYPE_TO_NP_TYPE[dtype]
......@@ -126,6 +125,7 @@ class ONNXGraph(Graph):
self.initializer = {}
self.place_holder_nodes = list()
self.get_place_holder_nodes()
self.value_infos = self.inferred_model_value_info(model)
def get_inner_nodes(self):
"""
......@@ -163,16 +163,12 @@ class ONNXGraph(Graph):
build topo_sort of ONNX model
"""
for layer in self.model.node:
self.node_map[layer.name] = ONNXGraphNode(layer)
#set op node's dtype and out_shapes
for item in self.model.value_info:
if item.name in self.node_map:
self.node_map[item.name].dtype = TENSOR_TYPE_TO_NP_TYPE[
item.type.tensor_type.elem_type]
self.node_map[item.name].out_shapes = [
dim.dim_value for dim in item.type.tensor_type.shape.dim
]
node = ONNXGraphNode(layer)
self.node_map[layer.name] = node
for opt in layer.output:
value_info = self.value_infos[opt]
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
for layer in self.model.input:
if layer.name not in self.node_map:
......@@ -200,7 +196,9 @@ class ONNXGraph(Graph):
else:
self.connect(in_node, layer_name)
#generate topo
# print([layer_name for layer_name, node in self.node_map.items()])
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
......@@ -227,6 +225,42 @@ class ONNXGraph(Graph):
weight = to_array(initializer)
yield name, weight
def inferred_model_value_info(self, graph):
"""
collect value/type info for an ONNX model
"""
assert isinstance(graph,
onnx.GraphProto), 'model is not a ModelProto instance'
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': False
}
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
for item in graph.output:
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
return value_info
class ONNXDecoder(object):
def __init__(self, onnx_model):
......@@ -241,7 +275,6 @@ class ONNXDecoder(object):
'some operator may cannot convert.',
model.opset_import[0].version)
check_model(model)
model = polish_model(model)
model = self.optimize_model_skip_op_for_inference(model)
......@@ -254,6 +287,8 @@ class ONNXDecoder(object):
self.onnx_graph = ONNXGraph(graph_def)
self.onnx_graph.build()
self.results_of_inference = dict()
def build_value_refs(self, nodes):
"""
build op reference of inputs and outputs
......@@ -456,10 +491,7 @@ class ONNXDecoder(object):
len(onnx_model.input), len(model.input)))
return onnx_model
def get_dynamic_shape_from_caffe2(self, layer, input_shapes):
"""
get dynamic shape from caffe2.backend
"""
def get_results_of_inference(self, model, input_shapes):
try:
import torch
version = torch.__version__
......@@ -472,26 +504,27 @@ class ONNXDecoder(object):
)
return
from caffe2.python.onnx.backend import prepare
shape = input_shapes[0]
np_images = np.random.rand(shape[0], shape[1], shape[2],
shape[3]).astype('float32')
num_onnx = self.split_model(self.model, layer)
prepared_backend = prepare(num_onnx, device='CPU')
infer_shapes = onnx.shape_inference.infer_shapes(model)
model.graph.ClearField('output')
model.graph.output.MergeFrom(infer_shapes.graph.value_info)
prepared_backend = prepare(model, device='CPU')
output = prepared_backend.run(inputs=np_images)
return output[0].tolist()
def get_dynamic_shape_from_onnx(self, layer, input_shapes):
for idx, value_info in enumerate(infer_shapes.graph.value_info):
self.results_of_inference[value_info.name] = output[idx]
return
def get_dynamic_shape_from_caffe2(self, layer, input_shapes):
"""
get dynamic shape from onnxruntime
get dynamic shape from caffe2.backend
"""
import onnxruntime as rt
from onnxruntime.backend import prepare
import numpy as np
num_onnx = self.split_model(self.model, layer)
sess = prepare(num_onnx)
shape = input_shapes[0]
print(shape)
np_images = np.random.rand(shape[0], shape[1], shape[2],
shape[3]).astype('float32')
output = sess.run(model=sess, inputs=np_images)
return output[0].tolist()
if len(self.results_of_inference) == 0:
self.get_results_of_inference(self.model, input_shapes)
output = self.results_of_inference[layer]
return output.tolist()
......@@ -24,6 +24,7 @@ default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Gather': ['gather', ['X'], ['Out'],
dict(axis='')],
......@@ -47,7 +48,14 @@ default_op_mapping = {
dict(keep_dim=1)
],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
dict(), dict(alpha=.01)]
dict(), dict(alpha=.01)],
'Tanh': ['tanh', ['X'], ['Out']],
}
activefunc_op_mapping = {
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
dict(), dict(alpha=.01)],
}
default_ioa_constraint = {
......
......@@ -22,6 +22,7 @@ from x2paddle.op_mapper.onnx_directly_map import default_op_mapping_field_values
from x2paddle.op_mapper.onnx_directly_map import default_op_mapping
from x2paddle.op_mapper.onnx_directly_map import default_ioa_constraint
import numpy as np
import onnx.numpy_helper as numpy_helper
import logging as _logging
from collections import OrderedDict as _dict
......@@ -66,6 +67,7 @@ class ONNXOpMapper(OpMapper):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
# print('translate{} layer_type is {}'.format(node_name, op))
if hasattr(self, op):
func = getattr(self, op)
func(node)
......@@ -134,10 +136,10 @@ class ONNXOpMapper(OpMapper):
param_attr=attr)
def place_holder(self, node):
self.input_shapes.append(node.out_shapes)
self.input_shapes.append(node.out_shapes[0])
attr = {
"dtype": string(node.dtype),
"shape": node.out_shapes,
"shape": node.out_shapes[0],
"name": string(node.layer_name),
"append_batch_size": 'False'
}
......@@ -151,7 +153,7 @@ class ONNXOpMapper(OpMapper):
if parameter is not None:
node = parameter
dtype = node.dtype
shape = node.out_shapes
shape = node.out_shapes[0]
self.weights[node.layer_name] = node.weight
attr = {
......@@ -184,8 +186,8 @@ class ONNXOpMapper(OpMapper):
pads = node.get_attr('pads')
mode = node.get_attr('mode', 'constant')
value = node.get_attr('value', 0.)
data_shape = val_x.out_shapes
output_shape = node.out_shapes
data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0]
assume_pad2d = False
attr = {}
if len(pads) == 4:
......@@ -200,8 +202,6 @@ class ONNXOpMapper(OpMapper):
attr['mode'] = string(mode)
else:
attr = {'pad_value': value}
assert mode == 'constant', 'mode {} is supported only in pad2d'.format(
mode)
fluid_op = 'pad'
if len(pads) == 4:
paddings = np.array(pads).reshape(
......@@ -209,6 +209,10 @@ class ONNXOpMapper(OpMapper):
elif len(pads) == 8:
paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0:
fluid_op = 'pad2d'
paddings = paddings[4:]
attr['mode'] = string(mode)
attr['paddings'] = paddings
if op_independent:
attr['name'] = string(node.layer_name)
......@@ -244,7 +248,7 @@ class ONNXOpMapper(OpMapper):
shape = node.get_attr('shape', None)
if shape is None:
shape = val_output.out_shapes
shape = val_output.out_shapes[0]
if shape is None:
shape = list(value.shape)
_logger.warning(
......@@ -271,7 +275,7 @@ class ONNXOpMapper(OpMapper):
val_scales = self.graph.get_node(node.layer.input[1], copy=True)
val_y, = self.graph.get_node(node.layer.output[0], copy=True)
out_shape_ = val_y.out_shapes
out_shape_ = val_y.out_shapes[0]
if out_shape_ is not None:
assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported'
out_shape_ = out_shape_[2:]
......@@ -289,7 +293,7 @@ class ONNXOpMapper(OpMapper):
else:
out_shape = None
if out_shape_ is None:
in_shape = val_x.out_shapes
in_shape = val_x.out_shapes[0]
assert in_shape is not None, 'out_shape required but not inferrable'
assert len(
in_shape) == 4, 'only 4-D Tensor as X and Y supported'
......@@ -311,11 +315,11 @@ class ONNXOpMapper(OpMapper):
def ConstantOfShape(self, node):
val_shape = self.graph.get_node(node.layer.input[0], copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
shape = _const_weight_or_none(val_shape)
if shape is None:
shape = node.out_shapes
shape = node.out_shapes[0]
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
......@@ -362,7 +366,7 @@ class ONNXOpMapper(OpMapper):
shape = self.decoder.get_dynamic_shape_from_caffe2(
val_shape.layer_name, self.input_shapes)
if shape is None:
shape = val_reshaped.out_shapes
shape = val_reshaped.out_shapes[0]
shape_dtype = val_shape.dtype
......@@ -417,7 +421,7 @@ class ONNXOpMapper(OpMapper):
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
input_shape = val_x.out_shapes
input_shape = val_x.out_shapes[0]
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
......@@ -572,6 +576,42 @@ class ONNXOpMapper(OpMapper):
output=node,
param_attr=attr)
def InstanceNormalization(self, node):
'''
y = scale * (x - mean) / sqrt(variance + epsilon) + B
'''
val_x = self.graph.get_node(node.layer.input[0], copy=True)
val_scale = self.graph.get_node(node.layer.input[1], copy=True)
val_b = self.graph.get_node(node.layer.input[2], copy=True)
epsilon = node.get_attr('epsilon', 1e-5)
num_out_channels = val_scale.out_shapes[0][0]
attr = {
"groups": num_out_channels,
"epsilon": epsilon,
"param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name),
"name": string(node.layer_name)
}
if val_scale.layer_type == 'Constant':
self.weights[val_scale.layer_name] = val_scale.get_attr('value')
if val_b.layer_type == 'Constant':
self.weights[val_b.layer_name] = val_b.get_attr('value')
# node_data_norm = node.layer_name +'data_norm'
node.fluid_code.add_layer("group_norm",
inputs=val_x,
output=node,
param_attr=attr)
# node.fluid_code.add_layer("elementwise_add",
# val_x.layer_name +','+ node_data_norm,
# output=node,
# param_attr=attr)
def Softmax(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True)
attr = {"name": string(node.layer_name)}
......@@ -610,12 +650,17 @@ class ONNXOpMapper(OpMapper):
def PRelu(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True)
val_slope = self.graph.get_node(node.layer.input[1], copy=True)
attr = {"name": string(node.layer_name), "mode": string('channel')}
if isinstance(val_slope, str):
attr["param_attr"] = string(val_slope.layer_name)
else:
attr["param_attr"] = string(val_slope.layer_name)
mode = 'channel'
shape_slope = val_slope.out_shapes[0]
if len(shape_slope) == 1:
mode = 'all'
elif len(shape_slope) > 2:
mode = 'element'
attr = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode)
}
node.fluid_code.add_layer("prelu",
inputs=val_x,
output=node,
......@@ -651,7 +696,7 @@ class ONNXOpMapper(OpMapper):
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
input_shape = val_x.out_shapes
input_shape = val_x.out_shapes[0]
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
......@@ -676,8 +721,8 @@ class ONNXOpMapper(OpMapper):
def GlobalAveragePool(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
input_shape = val_x.out_shapes
output_shape = val_y.out_shapes
input_shape = val_x.out_shapes[0]
output_shape = val_y.out_shapes[0]
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # N
if input_shape:
poolnd = len(input_shape) - 2 # NC...
......@@ -701,7 +746,7 @@ class ONNXOpMapper(OpMapper):
val_y = self.graph.get_node(node.layer.output[0], copy=True)
self.omit_nodes.append(val_w.layer_name)
input_shape = val_x.out_shapes
input_shape = val_x.out_shapes[0]
has_bias = len(node.layer.input) == 3
if has_bias:
......@@ -709,12 +754,12 @@ class ONNXOpMapper(OpMapper):
self.omit_nodes.append(val_b.layer_name)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = val_w.out_shapes[2:] # OI...
kernel_shape = val_w.out_shapes[0][2:] # OI...
assert kernel_shape == node.get_attr(
'kernel_shape'), 'kernel_shape in attr unmatches value_info' # HW
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
num_out_channels = val_w.out_shapes[0] # OI...
num_out_channels = val_w.out_shapes[0][0] # OI...
fluid_op = 'conv{}d'.format(convnd)
num_groups = node.get_attr('group', 1)
......@@ -749,3 +794,56 @@ class ONNXOpMapper(OpMapper):
inputs=val_x,
output=node,
param_attr=attr)
def ConvTranspose(self, node):
val_x = self.graph.get_node(node.layer.input[0], copy=True)
val_w = self.graph.get_node(node.layer.input[1], copy=True)
val_b = self.graph.get_node(node.layer.input[2], copy=True)
self.omit_nodes.append(val_w.layer_name)
self.omit_nodes.append(val_b.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
out_padding = node.get_attr('output_padding', [0, 0])
kernel_shape = node.get_attr('kernel_shape', val_w.out_shapes[0][2:])
assert kernel_shape, 'kernel_shape not inferred'
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_out_channels = val_w.out_shapes[0][1] # IO...
fluid_op = 'conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1) # optional
strides = node.get_attr('strides', [1] * convnd) # optional
dilations = node.get_attr('dilations', [1] * convnd) # optional
output_size = node.get_attr('output_shape', []) # optional
pads = node.get_attr('pads', [0] * (convnd * 2)) # optional
paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)
output_size = [0, 0]
print(val_x.out_shapes[0])
output_size[0] = (val_x.out_shapes[0][2] -
1) * strides[0] - 2 * paddings[0] + dilations[0] * (
kernel_shape[0] - 1) + 1 + out_padding[0]
output_size[1] = (val_x.out_shapes[0][3] -
1) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1]
print(output_size)
attr = {
'num_filters': num_out_channels,
'output_size': output_size or None,
'filter_size': kernel_shape,
'padding': paddings,
'stride': strides,
'dilation': dilations,
'groups': num_groups,
'param_attr': string(val_w.layer_name),
'bias_attr': string(val_b.layer_name),
'name': string(node.layer_name),
}
node.fluid_code.add_layer(fluid_op,
inputs=val_x,
output=node,
param_attr=attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册