Commit 324b75ee authored by channingss

fix bug & support new op for ssd

Parent: b6e359f1
@@ -89,6 +89,9 @@ def tf2paddle(model_path, save_dir):
     mapper.save_inference_model(save_dir)
 
 
 def caffe2paddle(proto, weight, save_dir, caffe_proto):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
...
@@ -17,7 +17,6 @@ from x2paddle.core.fluid_code import FluidCode
 from onnx.checker import ValidationError
 from onnx.checker import check_model
 from onnx.utils import polish_model
-from onnx.version_converter import convert_version
 from onnx import helper
 from onnx.helper import get_attribute_value, make_attribute
 from onnx.shape_inference import infer_shapes
@@ -26,6 +25,7 @@ from onnx.numpy_helper import to_array
 from onnx import AttributeProto, TensorProto, GraphProto
 from collections import OrderedDict as Dict
 import onnx
+from onnx.helper import ValueInfoProto
 import numpy as np
 from copy import deepcopy
 import logging as _logging
@@ -47,6 +47,7 @@ class ONNXGraphNode(GraphNode):
         self.weight_inputs = list()
         self.out_shapes = list()
         self.dtype = None
+        self.which_child = {}
 
     def get_attr_map(self):
         """
@@ -60,10 +61,9 @@ class ONNXGraphNode(GraphNode):
     @property
     def value(self):
         assert 'Constant' in self.layer_type, "Only Constant | ConstantOfShape node has value."
-        attr = self.layer.attribute['value']
         if 'value' not in self.attr_map:
             return None
-        return self.attr_map[name]
+        return self.attr_map['value']
 
     def get_attribute_value2(self, attr):
         """
@@ -105,18 +105,29 @@ class ONNXGraphDataNode(GraphNode):
         self.fluid_code = FluidCode()
         self.weight = None
         self.embeded_as = None
+        self.which_child = {}
 
     @property
     def out_shapes(self):
-        values = self.layer.type.tensor_type.shape.dim
-        out_shapes = list()
-        out_shapes.append([dim.dim_value for dim in values])
-        return out_shapes
+        if isinstance(self.layer, ValueInfoProto):
+            values = self.layer.type.tensor_type.shape.dim
+            out_shapes = list()
+            out_shapes.append([dim.dim_value for dim in values])
+            return out_shapes
+        else:
+            values = self.layer.dims
+            out_shapes = list()
+            out_shapes.append(values)
+            return out_shapes
 
     @property
     def dtype(self):
-        dtype = self.layer.type.tensor_type.elem_type
-        return TENSOR_TYPE_TO_NP_TYPE[dtype]
+        if isinstance(self.layer, ValueInfoProto):
+            dtype = self.layer.type.tensor_type.elem_type
+            return TENSOR_TYPE_TO_NP_TYPE[dtype]
+        else:
+            dtype = self.layer.data_type
+            return TENSOR_TYPE_TO_NP_TYPE[dtype]
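The new isinstance branches exist because a data node can now wrap either a graph input/output (a ValueInfoProto, whose shape lives under type.tensor_type) or a weight initializer (a TensorProto, whose shape lives in dims and whose element type lives in data_type). A minimal standalone sketch of the two access paths (the tensor names here are made up for illustration):

```python
import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto

# A graph input is a ValueInfoProto: shape in type.tensor_type.shape.dim,
# dtype in type.tensor_type.elem_type.
value_info = helper.make_tensor_value_info('x', TensorProto.FLOAT,
                                           [1, 3, 224, 224])
print([d.dim_value for d in value_info.type.tensor_type.shape.dim])
# -> [1, 3, 224, 224]

# An initializer is a TensorProto: shape in .dims, dtype in .data_type.
initializer = numpy_helper.from_array(
    np.zeros((64, 3, 3, 3), np.float32), name='w')
print(list(initializer.dims))  # -> [64, 3, 3, 3]
```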
class ONNXGraph(Graph):

@@ -165,18 +176,23 @@ class ONNXGraph(Graph):
         """
         build topo_sort of ONNX model
         """
-        data_node = self.place_holder_nodes[0]
-        value_info = self.value_infos[data_node]
-        input_shape = value_info['shape']
-        self.get_results_of_inference(self.onnx_model, input_shape)
+        data_nodes = self.place_holder_nodes
+        self.get_results_of_inference_rt(self.onnx_model, data_nodes)
         for layer in self.model.node:
             node = ONNXGraphNode(layer)
             self.node_map[layer.name] = node
             for opt in layer.output:
                 if opt in self.value_infos:
                     value_info = self.value_infos[opt]
-                    node.dtype = value_info['dtype']
-                    node.out_shapes.append(value_info['shape'])
+                    if len(value_info['shape']
+                           ) == 0 or value_info['dtype'] is None:
+                        _, dtype, shape = self.get_dynamic_shape(opt)
+                        node.dtype = dtype
+                        node.out_shapes.append(shape)
+                    else:
+                        node.dtype = value_info['dtype']
+                        node.out_shapes.append(value_info['shape'])
                 else:
                     _, dtype, shape = self.get_dynamic_shape(opt)
                     node.dtype = dtype
@@ -191,20 +207,40 @@ class ONNXGraph(Graph):
                 is_global_input=is_place_holder)
 
         #set data node's weight
-        for name, weight in self.graph_weights(self.model):
+        for initializer in self.model.initializer:
+            name = initializer.name
+            weight = to_array(initializer)
             if name in self.node_map:
                 if isinstance(self.node_map[name], ONNXGraphDataNode):
                     self.node_map[name].weight = weight
                     self.node_map[name].embeded_as = []
+            else:
+                self.node_map[name] = ONNXGraphDataNode(initializer,
+                                                        layer_name=name,
+                                                        is_global_input=False)
+                self.node_map[name].weight = weight
+                self.node_map[name].embeded_as = []
 
         #generate connection between nodes for topo
         for layer_name, node in self.node_map.items():
             if isinstance(node, ONNXGraphNode):
                 for idx, in_node in enumerate(node.layer.input):
                     if in_node not in self.node_map:
-                        raise Exception(
-                            'input[{}] of node[{}] does not exist in node_map'.
-                            format(in_node, layer_name))
+                        flag = 0
+                        for nd in self.model.node:
+                            for idx, opt in enumerate(nd.output):
+                                if opt == in_node:
+                                    self.connect(nd.name, layer_name)
+                                    flag = 1
+                                    print(nd.name + '->' + layer_name)
+                                    node.which_child[nd.name] = idx
+                                    break
+                            if flag == 1:
+                                break
+                        if flag == 0:
+                            raise Exception(
+                                'input[{}] of node[{}] does not exist in node_map'
+                                .format(in_node, layer_name))
                     else:
                         self.connect(in_node, layer_name)
 
         #generate topo
@@ -212,13 +248,14 @@ class ONNXGraph(Graph):
         self.input_nodes = self.place_holder_nodes
 
-    def get_nodes(self, names, copy=False):
-        """
-        get nodes by more than one name
-        """
-        nodes = []
-        for name in names:
-            nodes.add(self.get_node(name, copy=copy))
+    def get_input_node(self, node, idx=0, copy=False):
+        if len(node.which_child) == 0:
+            return super(ONNXGraph, self).get_node(node.inputs[idx], copy)
+        else:
+            ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
+            if ipt_node.layer_name in node.which_child:
+                ipt_node.index = node.which_child[ipt_node.layer_name]
+            return ipt_node
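The new get_input_node pairs with the which_child bookkeeping built above: when a tensor comes from a multi-output producer, the consumer records which output index of that producer it reads, and get_input_node tags the fetched node with that index so code generation can address the right output. A simplified, self-contained sketch of the idea (the data structures here are assumptions, not the real classes):

```python
# Sketch: a consumer remembers which output index of a producer feeds it,
# so codegen can later emit something like `split_0[1]`.
class Node:
    def __init__(self, name, outputs):
        self.layer_name = name
        self.outputs = outputs    # tensor names this node produces
        self.which_child = {}     # producer layer_name -> output index

split = Node('split_0', ['split_0_out0', 'split_0_out1'])
consumer = Node('concat_0', [])

# While building edges: consumer reads tensor 'split_0_out1', which is
# output index 1 of 'split_0'.
consumer.which_child[split.layer_name] = split.outputs.index('split_0_out1')
print(consumer.which_child)  # {'split_0': 1}
```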
 
     def graph_weights(self, graph):
         """
@@ -270,7 +307,7 @@ class ONNXGraph(Graph):
         }
         return value_info
 
-    def get_results_of_inference(self, model, shape):
+    def get_results_of_inference(self, model, data_nodes):
         try:
             import torch
             version = torch.__version__
@@ -284,9 +321,11 @@ class ONNXGraph(Graph):
             return
         from x2paddle.decoder.onnx_backend import prepare
 
-        np_images = np.random.rand(shape[0], shape[1], shape[2],
-                                   shape[3]).astype('float32')
+        inputs = []
+        for data_node in data_nodes:
+            value_info = self.value_infos[data_node]
+            ipt = np.random.random(value_info['shape']).astype('float32')
+            inputs.append(ipt)
 
         outputs = []
         for node in model.graph.node:
             value_info = helper.make_tensor_value_info(node.name,
@@ -301,15 +340,46 @@ class ONNXGraph(Graph):
             prepared_backend = prepare(model,
                                        device='CPU',
                                        no_check_UNSAFE=True)
-            res = prepared_backend.run(inputs=np_images)
+            res = prepared_backend.run(inputs=inputs)
             for idx, info in enumerate(tmp_outputs):
                 self.results_of_inference[info.name] = res[idx]
             outputs = outputs[254:]
         return
 
+    def get_results_of_inference_rt(self, model, data_nodes):
+        import onnxruntime as rt
+
+        inputs = []
+        for data_node in data_nodes:
+            value_info = self.value_infos[data_node]
+            ipt = np.random.random(value_info['shape']).astype('float32')
+            inputs.append(ipt)
+
+        model = onnx.shape_inference.infer_shapes(model)
+        outputs = []
+        for value_info in model.graph.value_info:
+            outputs.append(value_info)
+
+        model.graph.ClearField('output')
+        model.graph.output.MergeFrom(outputs)
+        onnx.save(model, './onnx_model_infer.onnx')
+        sess = rt.InferenceSession('./onnx_model_infer.onnx')
+        inputs_dict = {}
+        for i, ipt in enumerate(inputs):
+            inputs_dict[sess.get_inputs()[i].name] = ipt
+        res = sess.run(None, input_feed=inputs_dict)
+        for idx, info in enumerate(outputs):
+            self.results_of_inference[info.name] = res[idx]
+        return
+
     def get_dynamic_shape(self, layer):
         """
-        get dynamic shape from caffe2.backend
+        get dynamic shape from infer_result
         """
         output = self.results_of_inference[layer]
         return output.tolist(), output.dtype, output.shape
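get_results_of_inference_rt is what makes get_dynamic_shape work for every intermediate tensor: it promotes each inferred value_info entry to a graph output, then runs the model once through onnxruntime on random inputs and caches all results. A standalone sketch of the same trick (the model path and input shape here are assumptions):

```python
# Sketch: expose all intermediate tensors as outputs, run once, cache results.
import numpy as np
import onnx
import onnxruntime as rt

model = onnx.load('model.onnx')                  # hypothetical model file
model = onnx.shape_inference.infer_shapes(model)

outputs = list(model.graph.value_info)           # every inferred tensor
model.graph.ClearField('output')
model.graph.output.MergeFrom(outputs)

onnx.save(model, 'model_infer.onnx')
sess = rt.InferenceSession('model_infer.onnx')
feed = {sess.get_inputs()[0].name:
        np.random.random([1, 3, 300, 300]).astype('float32')}
results = dict(zip((o.name for o in outputs), sess.run(None, input_feed=feed)))
# results['some_tensor'].shape now gives that tensor's concrete shape.
```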
@@ -334,8 +404,8 @@ class ONNXDecoder(object):
         self.standardize_variable_name(model.graph)
         self.model = model
-        graph_def = model.graph
-        self.onnx_graph = ONNXGraph(graph_def, model)
+        graph = model.graph
+        self.onnx_graph = ONNXGraph(graph, model)
         self.onnx_graph.build()
 
     def build_value_refs(self, nodes):
@@ -476,7 +546,7 @@ class ONNXDecoder(object):
         if name == '':
             raise ValueError('name should not be empty')
-        for s in ' .*?\\/-:':
+        # for s in ' .*?\\/-:':
             name = name.replace(s, '_')
         return '_' + name
@@ -499,46 +569,3 @@
                 node.input[i] = self.make_variable_name(node.input[i])
             for i in range(len(node.output)):
                 node.output[i] = self.make_variable_name(node.output[i])
-
-    def split_model(self, model, outputs=None):
-        """
-        Takes a model and changes its outputs.
-        """
-        if outputs is None:
-            raise RuntimeError("outputs is None")
-        if outputs == model.graph.output[0].name:
-            return model
-        nodes = model.graph.node
-        keep_nodes = []
-
-        # all the nodes we need to keep.
-        for node in nodes:
-            if outputs in node.output:
-                keep_nodes.append(node)
-                break
-            keep_nodes.append(node)
-
-        infer_shapes = onnx.shape_inference.infer_shapes(model)
-        var_out = []
-        for value_info in infer_shapes.graph.value_info:
-            if value_info.name == outputs:
-                var_out.append(value_info)
-                break
-
-        graph = helper.make_graph(keep_nodes, model.graph.name,
-                                  model.graph.input, var_out,
-                                  model.graph.initializer)
-        onnx_model = helper.make_model(graph)
-        onnx_model.ir_version = model.ir_version
-        onnx_model.producer_name = model.producer_name
-        onnx_model.producer_version = model.producer_version
-        onnx_model.domain = model.domain
-        onnx_model.model_version = model.model_version
-        onnx_model.doc_string = model.doc_string
-
-        if len(onnx_model.graph.input) != len(model.graph.input):
-            raise RuntimeError("Input mismatch {} != {}".format(
-                len(onnx_model.input), len(model.input)))
-        return onnx_model
...

@@ -22,7 +22,8 @@ def InstanceNormalization_shape(input_shape):
 
 def InstanceNormalization_layer(inputs, name=None):
     # TODO(lvmengsi@baidu.com): Check the accuracy when using fluid.layers.layer_norm.
     epsilon = 1e-5
-    mean = fluid.layers.reduce_mean(inputs, dim=[2, 3], keep_dim=True)
+    input_ = inputs[0]
+    mean = fluid.layers.reduce_mean(input_, dim=[2, 3], keep_dim=True)
     var = fluid.layers.reduce_mean(fluid.layers.square(inputs - mean),
                                    dim=[2, 3],
                                    keep_dim=True)
@@ -36,13 +37,13 @@ def InstanceNormalization_layer(inputs, name=None):
                                      initializer=fluid.initializer.Constant(0.0),
                                      trainable=True)
     scale = fluid.layers.create_parameter(attr=scale_param,
-                                          shape=inputs.shape[1:2],
+                                          shape=input_.shape[1:2],
                                           dtype="float32")
     offset = fluid.layers.create_parameter(attr=offset_param,
-                                           shape=inputs.shape[1:2],
+                                           shape=input_.shape[1:2],
                                            dtype="float32")
-    tmp = fluid.layers.elementwise_mul(x=(inputs - mean), y=scale, axis=1)
+    tmp = fluid.layers.elementwise_mul(x=(input_ - mean), y=scale, axis=1)
     tmp = tmp / fluid.layers.sqrt(var + epsilon)
     tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
     return tmp
@@ -56,4 +57,5 @@ def InstanceNormalization_weights(name, data=None):
 register(kind='InstanceNormalization',
          shape=InstanceNormalization_shape,
          layer=InstanceNormalization_layer,
+         child_func=None,
          weights=InstanceNormalization_weights)
...

@@ -16,6 +16,7 @@ from .register import get_registered_layers
 #custom layer import begins
 
 from . import InstanceNormalization
+from . import NonMaxSuppression
 
 #custom layer import ends
 
 custom_layers = get_registered_layers()
@@ -95,6 +96,17 @@ def make_custom_layer(node):
     return inspect.getsource(layer_func), layer_func
 
 
+def make_custom_child_func(node):
+    """ get the code which implement the custom layer function
+    """
+    layer_type = node.layer_type
+    assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
+        layer_type)
+    child_func = custom_layers[layer_type]['child_func']
+    import inspect
+    return inspect.getsource(child_func), child_func
+
+
 def deal_weights(node, data=None):
     """ deal the weights of the custom layer
     """
...
@@ -17,7 +17,7 @@
 g_custom_layers = {}
 
-def register(kind, shape, layer, weights):
+def register(kind, shape, layer, child_func, weights):
     """ register a custom layer or a list of custom layers
 
     Args:
@@ -48,6 +48,7 @@ def register(kind, shape, layer, weights):
         g_custom_layers[k] = {
             'shape': shape,
             'layer': layer,
+            'child_func': child_func,
            'weights': weights
        }
...
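With the extra child_func slot, a registration now looks like the sketch below, for a hypothetical 'MyOp' layer (in this commit only InstanceNormalization is registered, with child_func=None; the register function is the one defined above):

```python
# Hypothetical custom-layer registration using the extended signature.
def MyOp_shape(input_shape):
    return input_shape

def MyOp_layer(inputs, name=None):
    return inputs[0]

def MyOp_helper():            # optional child function; may also be None
    pass

def MyOp_weights(name, data=None):
    return []

register(kind='MyOp',
         shape=MyOp_shape,
         layer=MyOp_layer,
         child_func=MyOp_helper,   # source of this helper is emitted too
         weights=MyOp_weights)
```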
@@ -32,6 +32,9 @@ default_op_mapping = {
     'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'],
             dict(),
             dict(axis=-1)],
+    'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'],
+            dict(),
+            dict(axis=-1)],
     'Clip': [
         'clip', ['X'], ['Out'],
         dict(),
@@ -42,6 +45,7 @@ default_op_mapping = {
                      dtype=_np.uint8).view(_np.float32)),
         )
     ],
+    'Ceil': ['ceil', ['X'], ['Out']],
     'ReduceMean': [
         'reduce_mean', ['X'], ['Out'],
         dict(axes='dim', keepdims='keep_dim'),
@@ -52,7 +56,11 @@ default_op_mapping = {
         dict(axes='dim', keepdims='keep_dim'),
         dict(keep_dim=1)
     ],
+    'ReduceMin': [
+        'reduce_min', ['X'], ['Out'],
+        dict(axes='dim', keepdims='keep_dim'),
+        dict(keep_dim=1)
+    ],
     #active function
     'Relu': ['relu', ['X'], ['Out']],
     'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
@@ -78,8 +86,7 @@ default_op_mapping = {
     'Softplus': ['softplus', ['X'], ['Out']],
     'Exp': ['exp', ['X'], ['Out']],
     'Softmax': ['softmax', ['X'], ['Out'],
-                dict(axis=''),
-                dict(axis=1)],
+                dict(), dict(axis=1)],
 }
 
 activefunc_op_mapping = {
...
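Each default_op_mapping row bundles the target fluid op, its input and output slot names, attribute renames, and attribute defaults. The new 'Sub' entry, for example, unpacks as in this sketch (the variable names are descriptive assumptions, not from the source):

```python
# How a default_op_mapping row is read: ONNX Sub(X, Y) maps onto
# fluid.layers.elementwise_sub with a default broadcast axis of -1.
entry = ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)]
fluid_op, input_slots, output_slots, attr_renames, default_attrs = entry

print(fluid_op)       # 'elementwise_sub'
print(default_attrs)  # {'axis': -1}
```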
@@ -24,15 +24,17 @@ from x2paddle.op_mapper.onnx_custom_layer import *
 from x2paddle.core.util import string
 import numpy as np
 import onnx.numpy_helper as numpy_helper
+from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 import logging as _logging
 from collections import OrderedDict as _dict
+import math
 
 _logger = _logging.getLogger(__name__)
 
 
 def _const_weight_or_none(node):
     if 'Constant' in node.layer_name:
-        return val.value
+        return node.value
     if isinstance(node, ONNXGraphDataNode):
         return node.weight
     return None
@@ -94,7 +96,7 @@ class ONNXOpMapper(OpMapper):
                 print(op)
         return False
 
-    def directly_map(self, node, *args, name='', **kwargs):
+    def directly_map(self, node, name='', *args, **kwargs):
         inputs = node.layer.input
         outputs = node.layer.output
         op_type = node.layer_type
@@ -127,34 +129,38 @@ class ONNXOpMapper(OpMapper):
             mapped_attrs.pop('_')
         fluid_attrs = default_attrs.copy()
         fluid_attrs.update(mapped_attrs)
-        val_inps = inputs if input_perm is None else list(
-            map(lambda i: inputs[i], input_perm))
+        inputs = inputs if input_perm is None else list(
+            map(lambda i: inputs[i], input_perm))
+        val_inps = []
+        for idx, ipt in enumerate(inputs):
+            val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
+
         val_outs = outputs if output_perm is None else list(
             map(lambda i: outputs[i], output_perm))
         attr = fluid_attrs
         if fluid_op not in ['shape', 'gather']:
             attr['name'] = string(node.layer_name)
         node.fluid_code.add_layer(fluid_op,
-                                  inputs=', '.join(val_inps),
+                                  inputs=val_inps,
                                   output=val_outs[0],
                                   param_attr=attr)
 
     def deal_custom_layer(self, node):
         op = node.layer_type
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
         custom_code, func = make_custom_layer(node)
+        child_func_code, child_func = make_custom_child_func(node)
         params = get_params(node.layer, node.layer_type)
         arg_names, kwargs = set_args(func, params)
         kwargs['name'] = string(node.layer_name)
-        inputs_node = []
-        inputs_node.append(node.inputs[0])
         node.fluid_code.add_layer(func.__code__.co_name,
-                                  inputs=inputs_node[0],
+                                  inputs=node.inputs,
                                   output=node,
                                   param_attr=kwargs,
                                   is_custom_layer=True)
         if op not in self.used_custom_layers:
             self.used_custom_layers[op] = custom_code
+        if op + '_child_func' not in self.used_custom_layers:
+            self.used_custom_layers[op + '_child_func'] = child_func_code
 
     def place_holder(self, node):
         self.input_shapes.append(node.out_shapes[0])
@@ -203,8 +209,8 @@ class ONNXOpMapper(OpMapper):
         return [0] * ndims, val_padded
 
     def _interpolate(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scales = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scales = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
 
         out_shape_ = val_y.out_shapes[0]
@@ -245,7 +251,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Pad(self, node, op_independent=True):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         pads = node.get_attr('pads')
         mode = node.get_attr('mode', 'constant')
         value = node.get_attr('value', 0.)
@@ -291,8 +297,18 @@ class ONNXOpMapper(OpMapper):
                                       param_attr=attr)
             return node.layer_name + '_paded'
 
+    def TopK(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        axes = node.get_attr('axes')
+        k = 10
+        attr = {'k': k, 'name': string(node.layer_name)}
+        node.fluid_code.add_layer('topk',
+                                  inputs=val_x,
+                                  output=node,
+                                  param_attr=attr)
+
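Note that this TopK hard-codes k = 10 and never uses the axes attribute it reads. In ONNX opset 10 and later, K arrives as a second constant input, so a variant could recover it with _const_weight_or_none. A hedged sketch of that idea, not part of this commit, reusing the helpers of the surrounding class:

```python
# Hypothetical variant: recover k from the constant second input
# instead of hard-coding it, falling back to 10 if it is not constant.
def TopK(self, node):
    val_x = self.graph.get_input_node(node, idx=0, copy=True)
    val_k = self.graph.get_input_node(node, idx=1, copy=True)
    k_const = _const_weight_or_none(val_k)        # e.g. array([10])
    k = int(k_const[0]) if k_const is not None else 10
    attr = {'k': k, 'name': string(node.layer_name)}
    node.fluid_code.add_layer('topk',
                              inputs=val_x,
                              output=node,
                              param_attr=attr)
```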
     def Unsqueeze(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         axes = node.get_attr('axes')
         attr = {'axes': axes, 'name': string(node.layer_name)}
         node.fluid_code.add_layer('unsqueeze',
@@ -301,7 +317,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Shrink(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         bias = node.get_attr('bias')
         lambd = node.get_attr('lambd')
         assert bias == 0.0, 'not support bias!=0'
@@ -358,8 +374,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Resize(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scales = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scales = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
 
         out_shape_ = val_y.out_shapes[0]
@@ -401,24 +417,66 @@ class ONNXOpMapper(OpMapper):
     def Upsample(self, node):
         self._interpolate(node)
 
+    def Gather(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        indices = self.graph.get_input_node(node, idx=1, copy=True)
+        indices_shape = indices.out_shapes[0]
+        axis = node.get_attr('axis')
+        print(indices.layer_name)
+        print(indices_shape)
+        assert len(
+            indices_shape) == 1, "Gather op don't support dim of indice >1 "
+        if axis == 0 and len(indices_shape) == 1:
+            node.fluid_code.add_layer('gather',
+                                      inputs=[val_x, indices],
+                                      output=node,
+                                      param_attr=None)
+        elif axis > 0 and len(indices_shape) == 1:
+            perm = [range(len(indices_shape))]
+            perm = [axis] + perm[:axis] + perm[axis + 1:]
+            attr_trans = {'perm': perm}
+            name_trans = val_x.layer_name + '_trans'
+            node.fluid_code.add_layer('transpose',
+                                      inputs=val_x,
+                                      output=name_trans,
+                                      param_attr=attr_trans)
+            node.fluid_code.add_layer('gather',
+                                      inputs=[name_trans, indices],
+                                      output=node,
+                                      param_attr=None)
+            node.fluid_code.add_layer('transpose',
+                                      inputs=node,
+                                      output=node,
+                                      param_attr=attr_trans)
+
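For axis > 0, Gather is emitted as transpose, gather along axis 0, then transpose back. A NumPy sketch of why the trick works (this sketch computes the inverse permutation explicitly with argsort):

```python
# Equivalence check for the transpose-gather-transpose trick.
import numpy as np

x = np.random.rand(2, 5, 3)
indices = np.array([0, 3])
axis = 1

perm = [axis] + [i for i in range(x.ndim) if i != axis]
gathered = np.transpose(x, perm)[indices]   # gather along axis 0
inv_perm = np.argsort(perm)                 # undo the axis swap
result = np.transpose(gathered, inv_perm)

assert np.array_equal(result, np.take(x, indices, axis=axis))
```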
     def Slice(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_starts = self.graph.get_input_node(node, idx=1, copy=True)
+        val_ends = self.graph.get_input_node(node, idx=2, copy=True)
+        val_axes = self.graph.get_input_node(node, idx=3, copy=True)
+        val_steps = self.graph.get_input_node(node, idx=4, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
-        axes = node.get_attr('axes')
-        starts = node.get_attr('starts')
-        ends = node.get_attr('ends')
+        starts = _const_weight_or_none(val_starts).copy()
+        ends = _const_weight_or_none(val_ends).copy()
+        axes = _const_weight_or_none(val_axes)
+        steps = _const_weight_or_none(val_steps)
+        self.omit_nodes.append(val_starts.layer_name)
+        self.omit_nodes.append(val_ends.layer_name)
+        self.omit_nodes.append(val_axes.layer_name)
+        self.omit_nodes.append(val_steps.layer_name)
         shape = val_x.out_shapes[0]
 
         if shape is not None:
             for idx, value in enumerate(starts):
-                if value > 2**63 - 1 // 2:
-                    value = value - ONNX_INT_MAX
-                    starts[idx] = shape[axes[idx]] + value
+                if value > shape[axes[idx]]:
+                    starts[idx] = shape[axes[idx]]
             for idx, value in enumerate(ends):
-                if value > 2**63 - 1 // 2:
-                    value = value - ONNX_INT_MAX
-                    ends[idx] = shape[axes[idx]] + value
+                if value > shape[axes[idx]]:
+                    ends[idx] = shape[axes[idx]]
         attr = {"axes": axes, "starts": starts, "ends": ends}
         node.fluid_code.add_layer('slice',
                                   inputs=val_x,
@@ -426,7 +484,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
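The rewritten Slice reads starts/ends/axes/steps from constant inputs (opset-10 style) and clamps out-of-range values to the dimension size; ONNX commonly encodes "slice to the end" with an INT64_MAX sentinel. A small self-contained sketch of that clamping:

```python
# Clamp INT64_MAX "to the end" sentinels to the real dimension sizes.
shape = [1, 3, 300, 300]
axes = [2, 3]
ends = [9223372036854775807, 9223372036854775807]   # INT64_MAX sentinels

for i, value in enumerate(ends):
    if value > shape[axes[i]]:
        ends[i] = shape[axes[i]]

print(ends)  # [300, 300]
```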
 
     def ConstantOfShape(self, node):
-        val_shape = self.graph.get_node(node.layer.input[0], copy=True)
+        val_shape = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
 
         shape = _const_weight_or_none(val_shape)
@@ -452,7 +510,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Split(self, node):
-        val_input = self.graph.get_node(node.layer.input[0], copy=True)
+        val_input = self.graph.get_input_node(node, idx=0, copy=True)
         var_outs = [val for val in node.layer.input]
 
         fluid_op = 'split'
@@ -466,10 +524,11 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Reshape(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_shape = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
         val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
         shape = None
 
         if isinstance(val_shape, ONNXGraphDataNode):
             self.omit_nodes.append(val_shape.layer_name)
@@ -503,7 +562,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Cast(self, node):
-        val_input = self.graph.get_node(node.layer.input[0], copy=True)
+        val_input = self.graph.get_input_node(node, idx=0, copy=True)
         val_output = self.graph.get_node(node.layer.output[0], copy=True)
 
         dtype = node.get_attr('to')
@@ -520,7 +579,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def AveragePool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
 
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         kernel_shape = node.get_attr("kernel_shape")
@@ -532,10 +591,10 @@ class ONNXOpMapper(OpMapper):
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
 
-        input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
 
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
+            input_shape = val_x.out_shapes[0]
             pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
             pad_w = get_same_padding(input_shape[3], kernel_shape[1],
@@ -560,7 +619,7 @@ class ONNXOpMapper(OpMapper):
     def Concat(self, node):
         inputs = []
         for i in range(len(node.layer.input)):
-            ipt = self.graph.get_node(node.layer.input[i], copy=True)
+            ipt = self.graph.get_input_node(node, idx=i, copy=True)
             if isinstance(ipt, str):
                 inputs.append(ipt)
             else:
@@ -568,12 +627,12 @@ class ONNXOpMapper(OpMapper):
         axis = node.get_attr('axis')
         attr = {'axis': axis}
         node.fluid_code.add_layer('concat',
-                                  inputs='[' + ', '.join(inputs) + ']',
+                                  inputs=inputs,
                                   output=node,
                                   param_attr=attr)
 
     def Flatten(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         axis = node.get_attr('axis', 1)
         attr = {"axis": str(axis), "name": string(node.layer_name)}
         node.fluid_code.add_layer('flatten',
@@ -582,9 +641,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Gemm(self, node):
-        val_a = self.graph.get_node(node.layer.input[0], copy=True)
-        val_b = self.graph.get_node(node.layer.input[1], copy=True)
-        val_c = self.graph.get_node(node.layer.input[2], copy=True)
+        val_a = self.graph.get_input_node(node, idx=0, copy=True)
+        val_b = self.graph.get_input_node(node, idx=1, copy=True)
+        val_c = self.graph.get_input_node(node, idx=2, copy=True)
 
         alpha = node.get_attr('alpha', 1.)  # optional
         beta = node.get_attr('beta', 1.)  # optional
@@ -627,8 +686,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Add(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         inputs = {
             "x": val_x,
             "y": val_y,
@@ -642,23 +701,24 @@ class ONNXOpMapper(OpMapper):
     def Sum(self, node):
         val_inps = node.layer.input
         inputs = {
-            "x": val_inps[0],
-            "y": val_inps[1],
+            "x": self.graph.get_input_node(node, idx=0, copy=True),
+            "y": self.graph.get_input_node(node, idx=1, copy=True),
         }
         node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)
 
-        for ipt in val_inps[2:]:
+        for idx, ipt in enumerate(val_inps[2:]):
+            y = self.graph.get_input_node(node, idx=idx, copy=True)
             inputs = {
                 "x": node.layer_name,
-                "y": ipt,
+                "y": y,
             }
             node.fluid_code.add_layer("elementwise_add",
                                       inputs=inputs,
                                       output=node)
 
     def MatMul(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         inputs = {"x": val_x, "y": val_y}
         attr = {"name": string(node.layer_name)}
         node.fluid_code.add_layer("matmul",
@@ -667,11 +727,11 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def BatchNormalization(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scale = self.graph.get_node(node.layer.input[1], copy=True)
-        val_b = self.graph.get_node(node.layer.input[2], copy=True)
-        val_mean = self.graph.get_node(node.layer.input[3], copy=True)
-        val_var = self.graph.get_node(node.layer.input[4], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
+        val_b = self.graph.get_input_node(node, idx=2, copy=True)
+        val_mean = self.graph.get_input_node(node, idx=3, copy=True)
+        val_var = self.graph.get_input_node(node, idx=4, copy=True)
 
         self.omit_nodes.append(val_scale.layer_name)
         self.omit_nodes.append(val_b.layer_name)
@@ -701,7 +761,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Transpose(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         perm = node.get_attr('perm')
         attr = {'perm': perm, "name": string(node.layer_name)}
         node.fluid_code.add_layer("transpose",
@@ -710,12 +770,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Mul(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
-
-        val_x_shape = val_x.out_shapes[0]
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         val_y_shape = val_y.out_shapes[0]
 
         slice_idx = 0
         for dim in val_y_shape:
             if dim == 1:
@@ -747,12 +804,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Div(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
-
-        val_x_shape = val_x.out_shapes[0]
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         val_y_shape = val_y.out_shapes[0]
 
         slice_idx = 0
         for dim in val_y_shape:
             if dim == 1:
@@ -784,7 +838,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Relu(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         attr = {"name": string(node.layer_name)}
         node.fluid_code.add_layer("relu",
                                   inputs=val_x,
@@ -792,8 +846,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def PRelu(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_slope = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_slope = self.graph.get_input_node(node, idx=1, copy=True)
 
         mode = 'channel'
         shape_slope = val_slope.out_shapes[0]
@@ -811,20 +865,20 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Squeeze(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        squeeze_dims = node.get_attr('squeeze_dims')
-        attr = {'axes': squeeze_dims, "name": string(node.layer_name)}
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        axes = node.get_attr('axes')
+        attr = {'axes': axes, "name": string(node.layer_name)}
         node.fluid_code.add_layer("squeeze",
                                   inputs=val_x,
                                   output=node,
                                   param_attr=attr)
 
     def Identity(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         node.fluid_code.add_layer("assign", inputs=val_x, output=node)
 
     def MaxPool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
 
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         assert node.get_attr(
@@ -839,10 +893,10 @@ class ONNXOpMapper(OpMapper):
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
 
-        input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
 
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
+            input_shape = val_x.out_shapes[0]
             pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
             pad_w = get_same_padding(input_shape[3], kernel_shape[1],
@@ -863,8 +917,18 @@ class ONNXOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)
 
+    # def Tile(self, node):
+    #     pass
+
+    # def Loop(self, node):
+    #     pass
+
+    # def NonMaxSuppression(self, node):
+    #     pass
+
     def GlobalAveragePool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
 
         input_shape = val_x.out_shapes[0]
         output_shape = val_y.out_shapes[0]
@@ -886,21 +950,19 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def Conv(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_w = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_w = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
 
         self.omit_nodes.append(val_w.layer_name)
 
         has_bias = len(node.layer.input) == 3
         if has_bias:
-            val_b = self.graph.get_node(node.layer.input[2], copy=True)
+            val_b = self.graph.get_input_node(node, idx=2, copy=True)
             self.omit_nodes.append(val_b.layer_name)
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
 
-        kernel_shape = val_w.out_shapes[0][2:]  # OI...
-        assert kernel_shape == node.get_attr(
-            'kernel_shape'), 'kernel_shape in attr unmatches value_info'  # HW
+        kernel_shape = node.get_attr('kernel_shape')
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
         num_out_channels = val_w.out_shapes[0][0]  # OI...
@@ -941,9 +1003,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)
 
     def ConvTranspose(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_w = self.graph.get_node(node.layer.input[1], copy=True)
-        val_b = self.graph.get_node(node.layer.input[2], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_w = self.graph.get_input_node(node, idx=1, copy=True)
+        val_b = self.graph.get_input_node(node, idx=2, copy=True)
 
         self.omit_nodes.append(val_w.layer_name)
         self.omit_nodes.append(val_b.layer_name)
@@ -952,7 +1014,7 @@ class ONNXOpMapper(OpMapper):
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         out_padding = node.get_attr('output_padding', [0, 0])
-        kernel_shape = node.get_attr('kernel_shape', val_w.out_shapes[0][2:])
+        kernel_shape = node.get_attr('kernel_shape')
         assert kernel_shape, 'kernel_shape not inferred'
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
...