未验证 提交 0933e73f 编写于 作者: J Jason 提交者: GitHub

Merge pull request #273 from Channingss/develop_paddle1.8

support paddle1.8, remove onnxruntime
......@@ -11,3 +11,6 @@ x2paddle -f tensorflow -m tf.pb -s pd-model --without_data_format_optimization -
```
> 1. 目前Tensorflow的CV模型大部分均为`NHWC`的输入格式,而Paddle的默认输入格式为`NCHW`,因此X2Paddle在转换过程中,会对如`axis`, `shape`等参数进行转换,适应Paddle的NCHW格式。但在这种情况下,可能会由于TensorFlow模型太复杂,导致出错。 指定`--without_data_format_optimization`后,会停止对`axis`,`shape`等参数的优化(这可能会带来一定数量的transpose操作)
**Q3. ONNX模型转换过程中,提示『Unknown shape for input tensor[tensor name: "input"] -> shape: ['batch', 'sequence'], Please define shape of input here』**
A:该提示信息表示从ONNX的模型中获取到输入tensor(tensor名为"input:)的shape是语义象征性的['batch', 'sequence'],而不是dim为int类型的shape,从而可能会因为部分node的shape无法推理,导致转换失败。所以用户可以尝试手动在提示后输入详细的shape信息,如:-1,3,224,224 其中-1表示Batch
......@@ -10,12 +10,12 @@ X2Paddle在多个主流的CV模型上,测试过TensorFlow/Caffe/ONNX模型的
## 环境依赖
python == 2.7 | python >= 3.5
paddlepaddle >= 1.6.0
paddlepaddle >= 1.8.0
**按需安装以下依赖**
tensorflow : tensorflow == 1.14.0
caffe : 无
onnx : onnx == 1.6.0 onnxruntime == 1.0.0
onnx : onnx == 1.6.0
## 安装
### 安装方式一(推荐)
......@@ -63,6 +63,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
## 使用转换后的模型
转换后的模型包括`model_with_code``inference_model`两个目录。
`model_with_code`中保存了模型参数,和转换后的python模型代码
......
......@@ -175,13 +175,15 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
import onnxruntime
model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model, save_dir)
mapper = ONNXOpMapper(model)
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
print("Model optimized.")
optimizer.delete_redundance_code()
print("Paddle model and code generating ...")
mapper.save_inference_model(save_dir, params_merge)
print("Paddle model and code generated.")
def paddle2onnx(model_path, save_dir):
......@@ -211,18 +213,6 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "onnx":
try:
import onnxruntime as rt
version = rt.__version__
if version != '1.0.0':
print("[ERROR] onnxruntime==1.0.0 is required")
return
except:
print(
"[ERROR] onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\"."
)
try:
import paddle
v0, v1, v2 = paddle.__version__.split('.')
......@@ -261,6 +251,7 @@ def main():
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
params_merge = False
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
......
......@@ -25,6 +25,7 @@ class Layer(object):
self.inputs = dict()
self.output = None
self.is_custom_layer = False
self.use_fluid = False
def get_code(self):
layer_code = ""
......@@ -38,6 +39,8 @@ class Layer(object):
layer_code = layer_code + self.op + "("
elif self.op == "=":
layer_code = layer_code
elif self.use_fluid:
layer_code = layer_code + "fluid." + self.op + "("
else:
layer_code = layer_code + "fluid.layers." + self.op + "("
......@@ -108,9 +111,11 @@ class FluidCode(object):
inputs,
output,
param_attr=None,
use_fluid=False,
is_custom_layer=False):
layer = Layer()
layer.op = op
layer.use_fluid = use_fluid
layer.is_custom_layer = is_custom_layer
if inputs is not None:
layer.inputs = inputs
......
......@@ -29,11 +29,14 @@ def export_paddle_param(param, param_name, dir):
"bool": [framework_pb2.VarType.BOOL, None]
}
shape = param.shape
if str(param.dtype) in ['uint8', 'uint_8', 'bool']:
param = param.astype('int64')
if len(shape) == 0:
assert param.size == 1, "Unexpected situation happend!"
shape = [1]
assert str(param.dtype) in dtype_map, "Unknown dtype of params."
assert str(
param.dtype) in dtype_map, "Unknown dtype {} of params: {}.".format(
str(param.dtype), param_name)
fp = open(os.path.join(dir, param_name), 'wb')
numpy.array([0], dtype='int32').tofile(fp)
numpy.array([0], dtype='int64').tofile(fp)
......
......@@ -14,6 +14,7 @@
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.decoder.onnx_shape_inference import SymbolicShapeInference
from onnx.checker import ValidationError
from onnx.checker import check_model
from onnx.utils import polish_model
......@@ -53,7 +54,7 @@ class ONNXGraphNode(GraphNode):
convert ONNX node attributes to dict
"""
return {
attr.name: self.get_attribute_value2(attr)
attr.name: self.get_attribute_value(attr)
for attr in self.layer.attribute
}
......@@ -64,7 +65,7 @@ class ONNXGraphNode(GraphNode):
return None
return self.attr_map['value']
def get_attribute_value2(self, attr):
def get_attribute_value(self, attr):
"""
get_attribute_value enhanced
"""
......@@ -130,43 +131,90 @@ class ONNXGraphDataNode(GraphNode):
class ONNXGraph(Graph):
def __init__(self, onnx_model):
super(ONNXGraph, self).__init__(onnx_model.graph)
self.onnx_model = onnx_model
super(ONNXGraph, self).__init__(onnx_model)
self.fixed_input_shape = {}
self.initializer = {}
self.place_holder_nodes = list()
self.value_infos = {}
self.graph = onnx_model.graph
self.get_place_holder_nodes()
self.value_infos = self.inferred_model_value_info(self.model)
self.results_of_inference = dict()
print("shape inferencing ...")
self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape)
print("shape inferenced.")
self.build()
self.collect_value_infos()
self.allocate_shapes()
def get_inner_nodes(self):
"""
generate inner node of ONNX model
"""
inner_nodes = []
if not isinstance(self.model, onnx.GraphProto):
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
return
for initializer in self.model.initializer:
for initializer in self.graph.initializer:
name = initializer.name
inner_nodes.append(name)
return inner_nodes
def get_symbolic_shape(self, dims):
shape = []
for dim in dims:
if dim.HasField('dim_param'):
shape.append(dim.dim_param)
else:
shape.append(dim.dim_value)
return shape
def check_input_shape(self, vi):
if vi.type.HasField('tensor_type'):
for dim in vi.type.tensor_type.shape.dim:
if dim.HasField(
'dim_param') and vi.name not in self.fixed_input_shape:
shape = self.get_symbolic_shape(
vi.type.tensor_type.shape.dim)
print(
"Unknown shape for input tensor[tensor name: '{}'] -> shape: {}, Please define shape of input here,\nNote:you can use visualization tools like Netron to check input shape."
.format(vi.name, shape))
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
if shape.count("-1") > 1:
print("Only 1 dimension can be -1, type again:)")
else:
right_shape_been_input = True
if shape == 'N':
break
shape = [int(dim) for dim in shape.strip().split(',')]
assert shape.count(-1) <= 1, "Only one dimension can be -1"
self.fixed_input_shape[vi.name] = shape
break
def get_place_holder_nodes(self):
"""
generate place_holder node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
input_nodes = [value.name for value in self.model.input]
for ipt_data in input_nodes:
if ipt_data not in inner_nodes:
self.place_holder_nodes.append(ipt_data)
for ipt_vi in self.graph.input:
if ipt_vi.name not in inner_nodes:
self.check_input_shape(ipt_vi)
self.place_holder_nodes.append(ipt_vi.name)
def get_output_nodes(self):
"""
generate output_nodes node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
output_nodes = [value.name for value in self.model.output]
output_nodes = [value.name for value in self.graph.output]
for opt_data in output_nodes:
if opt_data not in inner_nodes:
self.output_nodes.append(opt_data)
......@@ -183,11 +231,11 @@ class ONNXGraph(Graph):
"""
build topo_sort of ONNX model
"""
for layer in self.model.node:
for layer in self.graph.node:
node = ONNXGraphNode(layer)
self.node_map[layer.name] = node
for layer in self.model.input:
for layer in self.graph.input:
if layer.name not in self.node_map:
is_place_holder = self.is_place_holder_nodes(layer.name)
self.node_map[layer.name] = ONNXGraphDataNode(
......@@ -196,7 +244,7 @@ class ONNXGraph(Graph):
is_global_input=is_place_holder)
#set data node's weight
for initializer in self.model.initializer:
for initializer in self.graph.initializer:
name = initializer.name
weight = to_array(initializer)
if name in self.node_map:
......@@ -228,7 +276,7 @@ class ONNXGraph(Graph):
continue
if in_node not in self.node_map:
flag = 0
for nd in self.model.node:
for nd in self.graph.node:
for idx, opt in enumerate(nd.output):
if opt == in_node:
self.connect(nd.name, layer_name)
......@@ -256,81 +304,68 @@ class ONNXGraph(Graph):
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node
def graph_weights(self, graph):
def graph_weights(self):
"""
generator for weights
"""
if not isinstance(graph, onnx.GraphProto):
if not isinstance(self.graph, onnx.GraphProto):
logger.error('graph is not a GraphProto instance')
return
for initializer in graph.initializer:
for initializer in self.graph.initializer:
name = initializer.name
weight = to_array(initializer)
yield name, weight
def inferred_model_value_info(self, graph):
def collect_value_infos(self):
"""
collect value/type info for an ONNX model
"""
assert isinstance(graph,
assert isinstance(self.graph,
onnx.GraphProto), 'model is not a ModelProto instance'
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = {
for item in self.graph.value_info:
self.value_infos[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': False
}
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
for item in graph.output:
assert item.name not in value_info
value_info[item.name] = {
'dtype':
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
'shape':
[dim.dim_value for dim in item.type.tensor_type.shape.dim],
'external': True
}
return value_info
def allocate_shapes(self):
"""
run shape inference
"""
for layer in self.graph.node:
node = self.node_map[layer.name]
for opt in layer.output:
if opt in self.value_infos:
value_info = self.value_infos[opt]
#if len(value_info['shape']) == 0 or value_info[
# 'dtype'] is None or 0 in value_info['shape']:
# #TODO add node shape inference
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else:
node.out_shapes.append([])
class ONNXDecoder(object):
def __init__(self, onnx_model):
model = onnx.load(onnx_model)
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
model.ir_version, model.opset_import[0].version))
if model.opset_import[0].version < 9:
_logger.warning(
'Now, onnx2paddle support convert onnx model opset_verison == 9,'
'opset_verison of your onnx model is %d < 9,'
'some operator maybe unsuccessful in convertion.',
model.opset_import[0].version)
check_model(model)
self.check_model_running_state(onnx_model)
model = onnx.shape_inference.infer_shapes(model)
model = self.optimize_model_skip_op_for_inference(model)
model = self.optimize_model_strip_initializer(model)
self.standardize_variable_name(model.graph)
self.model = model
graph = model.graph
self.onnx_graph = ONNXGraph(model)
self.onnx_graph.build()
onnx_model.ir_version, onnx_model.opset_import[0].version))
self.op_set = onnx_model.opset_import[0].version
check_model(onnx_model)
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_model_strip_initializer(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
#self.onnx_model = onnx_model
def build_value_refs(self, nodes):
"""
......@@ -373,14 +408,13 @@ class ONNXDecoder(object):
processed += 1
return processed
def optimize_model_skip_op_for_inference(self, model, op_list=None):
def optimize_model_skip_op(self, model, op_list=None):
"""
skip ops can be bypassed for inference
"""
nodes = model.graph.node
if op_list is None:
op_list = ['Dropout']
nodes = model.graph.node
input_refs, output_refs = self.build_value_refs(nodes)
ret = type(model)()
ret.CopyFrom(model)
......@@ -473,38 +507,11 @@ class ONNXDecoder(object):
name = name.replace(s, '_')
return 'x2paddle_' + name
def check_model_running_state(self, model_path):
import onnxruntime as rt
model = onnx.load(model_path)
model = onnx.shape_inference.infer_shapes(model)
if len(model.graph.value_info) < len(model.graph.node) - 1:
_logger.warning(
'During conversion of your model, some operators will be assignd node.out_shape==None, '
'refer to https://github.com/onnx/onnx/blob/master/docs/ShapeInference.md'
)
try:
datatype_map = {
'tensor(int64)': 'int',
'tensor(float)': 'float32',
'tensor(int32)': 'int32'
}
input_dict = {}
sess = rt.InferenceSession(model_path)
for ipt in sess.get_inputs():
datatype = datatype_map[ipt.type]
input_dict[ipt.name] = np.random.random(ipt.shape).astype(
datatype)
res = sess.run(None, input_feed=input_dict)
except:
raise Exception(
"onnxruntime inference onnx model failed, Please confirm the correctness of onnx model by onnxruntime, if onnx model is correct, please submit issue in github."
)
def standardize_variable_name(self, graph):
def optimize_node_name(self, model):
"""
standardize variable name for paddle's code
"""
graph = model.graph
for initializer in graph.initializer:
initializer.name = self.make_variable_name(initializer.name)
for ipt in graph.input:
......@@ -523,3 +530,4 @@ class ONNXDecoder(object):
node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i])
return model
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference Code from https://github.com/microsoft/onnxruntime, Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import argparse
import numpy as np
import onnx
import sys
from onnx import helper, numpy_helper, shape_inference
import sympy
from packaging import version
assert version.parse(onnx.__version__) >= version.parse("1.5.0")
def get_attribute(node, attr_name, default_value=None):
found = [attr for attr in node.attribute if attr.name == attr_name]
if found:
return helper.get_attribute_value(found[0])
return default_value
def get_dim_from_type_proto(dim):
return getattr(dim, dim.WhichOneof('value')) if type(
dim.WhichOneof('value')) == str else None
def get_shape_from_type_proto(type_proto):
return [
get_dim_from_type_proto(d) for d in type_proto.tensor_type.shape.dim
]
def get_shape_from_sympy_shape(sympy_shape):
sympy_shape = [
None if i is None else (int(i) if is_literal(i) else str(i))
for i in sympy_shape
]
return sympy_shape
def is_literal(dim):
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (
hasattr(dim, 'is_number') and
dim.is_number) # or (hasattr(dim, 'is_integer') and dim.is_integer)
def handle_negative_axis(axis, rank):
assert axis < rank and axis >= -rank
return axis if axis >= 0 else rank + axis
def get_opset(mp, domain=['', 'onnx', 'ai.onnx']):
if type(domain) != list:
domain = [domain]
for opset in mp.opset_import:
if opset.domain in domain:
return opset.version
return None
def as_scalar(x):
if type(x) == list:
assert len(x) == 1
return x[0]
elif type(x) == np.ndarray:
return np.asscalar(x)
else:
return x
def as_list(x, keep_none):
if type(x) == list:
return x
elif type(x) == np.ndarray:
return list(x)
elif keep_none and x is None:
return None
else:
return [x]
def sympy_reduce_product(x):
if type(x) == list:
value = sympy.Integer(1)
for v in x:
value = value * v
else:
value = x
return value
class SymbolicShapeInference:
def __init__(self, int_max, auto_merge, guess_output_rank, verbose):
self.dispatcher_ = {
'Add': self._infer_symbolic_compute_ops,
'ArrayFeatureExtractor': self._infer_ArrayFeatureExtractor,
'AveragePool': self._infer_Pool,
'Cast': self._infer_Cast,
'CategoryMapper': self._infer_CategoryMapper,
'Compress': self._infer_Compress,
'Concat': self._infer_Concat,
'ConstantOfShape': self._infer_ConstantOfShape,
'Conv': self._infer_Conv,
'CumSum': self._pass_on_shape_and_type,
'Div': self._infer_symbolic_compute_ops,
'Expand': self._infer_Expand,
'Equal': self._infer_symbolic_compute_ops,
'Gather': self._infer_Gather,
'GatherElements': self._infer_GatherElements,
'GatherND': self._infer_GatherND,
'If': self._infer_If,
'Loop': self._infer_Loop,
'MatMul': self._infer_MatMul,
'MatMulInteger16': self._infer_MatMulInteger,
'MaxPool': self._infer_Pool,
'Max': self._infer_symbolic_compute_ops,
'Min': self._infer_symbolic_compute_ops,
'Mul': self._infer_symbolic_compute_ops,
'NonMaxSuppression': self._infer_NonMaxSuppression,
'NonZero': self._infer_NonZero,
'OneHot': self._infer_OneHot,
'Pad': self._infer_Pad,
'Range': self._infer_Range,
'ReduceProd': self._infer_ReduceProd,
'Reshape': self._infer_Reshape,
'Resize': self._infer_Resize,
'Round': self._pass_on_shape_and_type,
'Scan': self._infer_Scan,
'ScatterElements': self._infer_ScatterElements,
'Shape': self._infer_Shape,
'Size': self._infer_Size,
'Slice': self._infer_Slice,
'Split': self._infer_Split,
'Squeeze': self._infer_Squeeze,
'Sub': self._infer_symbolic_compute_ops,
'Tile': self._infer_Tile,
'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap
}
self.run_ = True
self.suggested_merge_ = {}
self.symbolic_dims_ = {}
self.input_symbols_ = {}
self.auto_merge_ = auto_merge
self.guess_output_rank_ = guess_output_rank
self.verbose_ = verbose
self.int_max_ = int_max
def _add_suggested_merge(self, symbols, apply=False):
assert all([(type(s) == str and s in self.symbolic_dims_) or
is_literal(s) for s in symbols])
symbols = set(symbols)
for k, v in self.suggested_merge_.items():
if k in symbols:
symbols.remove(k)
symbols.add(v)
map_to = None
# if there is literal, map to it first
for s in symbols:
if is_literal(s):
map_to = s
break
# when no literals, map to input symbolic dims, then existing symbolic dims
if map_to is None:
for s in symbols:
if s in self.input_symbols_:
map_to = s
break
if map_to is None:
for s in symbols:
if type(self.symbolic_dims_[s]) == sympy.Symbol:
map_to = s
break
# when nothing to map to, use the shorter one
if map_to is None:
if self.verbose_ > 0:
print(
'Potential unsafe merge between symbolic expressions: ({})'.
format(','.join(symbols)))
symbols_list = list(symbols)
lens = [len(s) for s in symbols_list]
map_to = symbols_list[lens.index(min(lens))]
symbols.remove(map_to)
for s in symbols:
if s == map_to:
continue
if is_literal(map_to) and is_literal(s):
assert int(map_to) == int(s)
self.suggested_merge_[s] = int(map_to) if is_literal(
map_to) else map_to
for k, v in self.suggested_merge_.items():
if v == s:
self.suggested_merge_[k] = map_to
if apply and self.auto_merge_:
self._apply_suggested_merge()
def _apply_suggested_merge(self, graph_input_only=False):
if not self.suggested_merge_:
return
for i in list(self.out_mp_.graph.input) + (
[] if graph_input_only else list(self.out_mp_.graph.value_info)):
for d in i.type.tensor_type.shape.dim:
if d.dim_param in self.suggested_merge_:
v = self.suggested_merge_[d.dim_param]
if is_literal(v):
d.dim_value = int(v)
else:
d.dim_param = v
def _preprocess(self, in_mp, input_shapes=None):
out_mp = onnx.ModelProto()
out_mp.CopyFrom(in_mp)
out_mp.graph.ClearField('node')
self.out_mp_ = out_mp
defined = set([
i.name
for i in list(in_mp.graph.input) + list(in_mp.graph.initializer)
])
pending_nodes = []
# returns True if no more ready nodes
def _insert_ready_nodes():
ready_nodes = [
pn for pn in pending_nodes
if all([i in defined for i in pn.input if i])
]
for rn in ready_nodes:
self.out_mp_.graph.node.add().CopyFrom(rn)
for o in rn.output:
defined.add(o)
pending_nodes.remove(rn)
return not ready_nodes
# constant op -> initializer, topological sort
for in_n in in_mp.graph.node:
if in_n.op_type == 'Constant':
t = get_attribute(in_n, 'value')
t.name = in_n.output[0]
self.out_mp_.graph.initializer.add().CopyFrom(t)
defined.add(t.name)
else:
pending_nodes.append(in_n)
_insert_ready_nodes()
while pending_nodes:
if _insert_ready_nodes():
break
if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ')
print(
* [n.op_type + ': ' + n.output[0] for n in pending_nodes],
sep='\n')
if input_shapes is not None:
for input_name, shape in input_shapes.items():
for idx in range(len(self.out_mp_.graph.input)):
if self.out_mp_.graph.input[idx].name == input_name:
value_info = self.out_mp_.graph.input[idx]
del self.out_mp_.graph.input[idx]
self.out_mp_.graph.input.append(
helper.make_tensor_value_info(
value_info.name,
value_info.type.tensor_type.elem_type, shape))
self.initializers_ = dict(
[(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict(
[(i.name, i) for i in list(self.out_mp_.graph.input)])
self.known_vi_.update(
dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type,
list(i.dims)))
for i in self.out_mp_.graph.initializer]))
def _merge_symbols(self, dims):
if not all([type(d) == str for d in dims]):
if self.auto_merge_:
assert len(
dims
) == 2 # only allow symbol->int merge in binary ops for now
is_int = [is_literal(d) for d in dims]
if sum(is_int) == 1:
int_dim = is_int.index(1)
if self.verbose_ > 0:
print('dim {} has been merged with value {}'.format(
dims[1 - int_dim], dims[int_dim]))
self._check_merged_dims(dims, allow_broadcast=False)
return dims[int_dim]
else:
if self.verbose_ > 0:
print('dim {} has been mergd with dim {}'.format(dims[
0], dims[1]))
return dims[0]
else:
return None
if all([d == dims[0] for d in dims]):
return dims[0]
merged = [
self.suggested_merge_[d] if d in self.suggested_merge_ else d
for d in dims
]
if all([d == merged[0] for d in merged]):
assert merged[0] in self.symbolic_dims_
return merged[0]
else:
return None
# broadcast from right to left, and merge symbolic dims if needed
def _broadcast_shapes(self, shape1, shape2):
new_shape = []
rank1 = len(shape1)
rank2 = len(shape2)
new_rank = max(rank1, rank2)
for i in range(new_rank):
dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
if dim1 == 1 or dim1 == dim2:
new_dim = dim2
elif dim2 == 1:
new_dim = dim1
else:
new_dim = self._merge_symbols([dim1, dim2])
if not new_dim:
# warning about unsupported broadcast when not auto merge
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
if self.auto_merge_:
self._add_suggested_merge([dim1, dim2], apply=True)
else:
print('unsupported broadcast between ' + str(dim1) + ' '
+ str(dim2))
new_shape = [new_dim] + new_shape
return new_shape
def _get_shape(self, node, idx):
name = node.input[idx]
shape = []
if name in self.known_vi_:
shape = get_shape_from_type_proto(self.known_vi_[name].type)
elif name in self.initializers_:
assert name in self.initializers_
shape = list(self.initializers_[name].dims)
return shape
def _get_initializer_value(self, node, idx):
name = node.input[idx]
if name in self.initializers_:
value = numpy_helper.to_array(self.initializers_[name])
return value
else:
return False
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
def _get_sympy_shape(self, node, idx):
sympy_shape = []
for d in self._get_shape(node, idx):
if type(d) is str:
sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol(
d, integer=True))
else:
assert None != d
sympy_shape.append(d)
return sympy_shape
def _get_value(self, node, idx):
name = node.input[idx]
assert name in self.sympy_data_ or name in self.initializers_
return self.sympy_data_[
name] if name in self.sympy_data_ else numpy_helper.to_array(
self.initializers_[name])
def _try_get_value(self, node, idx):
if idx >= len(node.input):
return None
name = node.input[idx]
if name in self.sympy_data_ or name in self.initializers_:
return self._get_value(node, idx)
return None
def _update_computed_dims(self, new_sympy_shape):
for i, new_dim in enumerate(new_sympy_shape):
if not is_literal(new_dim) and not type(new_dim) == str:
str_dim = str(new_dim)
if str_dim in self.suggested_merge_:
new_sympy_shape[i] = self.symbolic_dims_[
self.suggested_merge_[str_dim]]
else:
# add new_dim if it's a computational expression
if not str(new_dim) in self.symbolic_dims_:
self.symbolic_dims_[str(new_dim)] = new_dim
def _onnx_infer_single_node(self, node):
# skip onnx shape inference for Scan/Loop
skip_infer = node.op_type in ['Scan', 'Loop']
if not skip_infer:
# run single node inference with self.known_vi_ shapes
# note that inference rely on initializer values is not handled
# as we don't copy initializer weights to tmp_graph for inference speed purpose
tmp_graph = helper.make_graph(
[node], 'tmp', [self.known_vi_[i] for i in node.input if i], [
helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED,
None) for i in node.output
])
self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
for i_o in range(len(node.output)):
o = node.output[i_o]
vi = self.out_mp_.graph.value_info.add()
if not skip_infer:
vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
self.known_vi_[o] = vi
def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True):
if self.verbose_ > 2:
print('Inferencing subgraph of node {} with output({}...): {}'.
format(node.name, node.output[0], node.op_type))
# node inputs are not passed directly to the subgraph
# it's up to the node dispatcher to prepare subgraph input
# for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
# besides, inputs in subgraph could shadow implicit inputs
subgraph_inputs = set([
i.name for i in list(subgraph.initializer) + list(subgraph.input)
])
subgraph_implicit_input = set([
name for name in self.known_vi_.keys()
if not name in subgraph_inputs
])
tmp_graph = helper.make_graph(
list(subgraph.node), 'tmp',
list(subgraph.input) +
[self.known_vi_[i] for i in subgraph_implicit_input], [
helper.make_tensor_value_info(i.name,
onnx.TensorProto.UNDEFINED, None)
for i in subgraph.output
])
tmp_graph.initializer.extend([
i for i in self.out_mp_.graph.initializer
if i.name in subgraph_implicit_input
])
tmp_graph.initializer.extend(subgraph.initializer)
self.tmp_mp_.graph.CopyFrom(tmp_graph)
symbolic_shape_inference = SymbolicShapeInference(
self.int_max_, self.auto_merge_, self.guess_output_rank_,
self.verbose_)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_)
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
self.tmp_mp_, self.sympy_data_.copy())
symbolic_shape_inference._update_output_from_vi()
if use_node_input:
# if subgraph uses node input, it needs to update to merged dims
subgraph.ClearField('input')
subgraph.input.extend(
symbolic_shape_inference.out_mp_.graph.input[:len(node.input)])
subgraph.ClearField('output')
subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
subgraph.ClearField('value_info')
subgraph.value_info.extend(
symbolic_shape_inference.out_mp_.graph.value_info)
subgraph.ClearField('node')
subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
# for new symbolic dims from subgraph output, add to main graph symbolic dims
subgraph_shapes = [
get_shape_from_type_proto(o.type)
for o in symbolic_shape_inference.out_mp_.graph.output
]
subgraph_new_symbolic_dims = set([
d for s in subgraph_shapes
if s for d in s if type(d) == str and not d in self.symbolic_dims_
])
new_dims = {}
for d in subgraph_new_symbolic_dims:
assert d in symbolic_shape_inference.symbolic_dims_
new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
self.symbolic_dims_.update(new_dims)
return symbolic_shape_inference
def _get_int_values(self, node, broadcast=False):
values = [self._try_get_value(node, i) for i in range(len(node.input))]
if all([v is not None for v in values]):
# some shape compute is in floating point, cast to int for sympy
for i, v in enumerate(values):
if type(v) != np.ndarray:
continue
if len(v.shape) > 1:
new_v = None # ignore value for rank > 1
elif len(v.shape) == 0:
new_v = int(np.asscalar(v))
else:
assert len(v.shape) == 1
new_v = [int(vv) for vv in v]
values[i] = new_v
values_len = [len(v) if type(v) == list else 0 for v in values]
max_len = max(values_len)
if max_len >= 1 and broadcast:
# broadcast
for i, v in enumerate(values):
if v is None:
continue # don't broadcast if value is unknown
if type(v) == list:
if len(v) < max_len:
values[i] = v * max_len
else:
assert len(v) == max_len
else:
values[i] = [v] * max_len
return values
def _compute_on_sympy_data(self, node, op_func):
assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]):
new_shape = []
is_list = [type(v) == list for v in values]
as_list = any(is_list)
if as_list:
data = [op_func(vs) for vs in zip(*values)]
self.sympy_data_[node.output[0]] = data
new_shape = np.array(data).shape
else:
data = op_func(values)
self.sympy_data_[node.output[0]] = data
new_shape = np.array(data).shape
vi = self.known_vi_[node.output[0]]
#print(node.output[0])
#print(new_shape)
#vi.CopyFrom(helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, list(new_shape)))
def _pass_on_sympy_data(self, node):
assert len(node.input) == 1 or node.op_type == 'Reshape'
self._compute_on_sympy_data(node, lambda x: x[0])
def _pass_on_shape_and_type(self, node):
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
self._get_shape(node, 0)))
def _new_symbolic_dim(self, prefix, dim):
new_dim = '{}_d{}'.format(prefix, dim)
if new_dim in self.suggested_merge_:
v = self.suggested_merge_[new_dim]
new_dim = sympy.Integer(int(v)) if is_literal(v) else v
else:
self.symbolic_dims_[new_dim] = sympy.Symbol(new_dim, integer=True)
return new_dim
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
return self._new_symbolic_dim('{}{}_o{}_'.format(
node.op_type, list(self.out_mp_.graph.node).index(node), out_idx),
dim)
def _new_symbolic_shape(self, rank, node, out_idx=0):
return [
self._new_symbolic_dim_from_output(node, out_idx, i)
for i in range(rank)
]
def _compute_conv_pool_shape(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
if len(node.input) > 1:
W_shape = self._get_sympy_shape(node, 1)
rank = len(W_shape) - 2 # number of spatial axes
kernel_shape = W_shape[-rank:]
sympy_shape[1] = W_shape[0]
else:
W_shape = None
kernel_shape = get_attribute(node, 'kernel_shape')
rank = len(kernel_shape)
assert len(sympy_shape) == rank + 2
# only need to symbolic shape inference if input has symbolic dims in spatial axes
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
if not any(is_symbolic_dims):
shape = get_shape_from_type_proto(self.known_vi_[node.output[0]]
.type)
if len(shape) > 0:
assert len(sympy_shape) == len(shape)
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
return sympy_shape
dilations = get_attribute(node, 'dilations', [1] * rank)
strides = get_attribute(node, 'strides', [1] * rank)
effective_kernel_shape = [(k - 1) * d + 1
for k, d in zip(kernel_shape, dilations)]
pads = get_attribute(node, 'pads')
if pads is None:
pads = [0] * (2 * rank)
auto_pad = get_attribute(node, 'auto_pad',
b'NOTSET').decode('utf-8')
if auto_pad != 'VALID' and auto_pad != 'NOTSET':
try:
residual = [
sympy.Mod(d, s)
for d, s in zip(sympy_shape[-rank:], strides)
]
total_pads = [
max(0, (k - s) if r == 0 else (k - r))
for k, s, r in zip(effective_kernel_shape, strides,
residual)
]
except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
total_pads = [
max(0, (k - s))
for k, s in zip(effective_kernel_shape, strides)
] # assuming no residual if sympy throws error
elif auto_pad == 'VALID':
total_pads = []
else:
total_pads = [0] * rank
else:
assert len(pads) == 2 * rank
total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
ceil_mode = get_attribute(node, 'ceil_mode', 0)
for i in range(rank):
effective_input_size = sympy_shape[-rank + i]
if len(total_pads) > 0:
effective_input_size = effective_input_size + total_pads[i]
if ceil_mode:
strided_kernel_positions = sympy.ceiling(
(effective_input_size - effective_kernel_shape[i]) /
strides[i])
else:
strided_kernel_positions = (
effective_input_size - effective_kernel_shape[i]
) // strides[i]
sympy_shape[-rank + i] = strided_kernel_positions + 1
return sympy_shape
def _check_merged_dims(self, dims, allow_broadcast=True):
if allow_broadcast:
dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
if not all([d == dims[0] for d in dims]):
self._add_suggested_merge(dims, apply=True)
def _compute_matmul_shape(self, node, output_dtype=None):
lhs_shape = self._get_shape(node, 0)
rhs_shape = self._get_shape(node, 1)
lhs_rank = len(lhs_shape)
rhs_rank = len(rhs_shape)
lhs_reduce_dim = 0
rhs_reduce_dim = 0
assert lhs_rank > 0 and rhs_rank > 0
if lhs_rank == 1 and rhs_rank == 1:
new_shape = []
elif lhs_rank == 1:
rhs_reduce_dim = -2
new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
elif rhs_rank == 1:
lhs_reduce_dim = -1
new_shape = lhs_shape[:lhs_reduce_dim]
else:
lhs_reduce_dim = -1
rhs_reduce_dim = -2
new_shape = self._broadcast_shapes(
lhs_shape[:-2], rhs_shape[:-2]) + [lhs_shape[-2]
] + [rhs_shape[-1]]
# merge reduce dim
self._check_merged_dims(
[lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
allow_broadcast=False)
if output_dtype is None:
# infer output_dtype from input type when not specified
output_dtype = self.known_vi_[node.input[
0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_dtype,
new_shape))
def _infer_ArrayFeatureExtractor(self, node):
data_shape = self._get_shape(node, 0)
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape[:-1] +
indices_shape))
def _infer_symbolic_compute_ops(self, node):
funcs = {
'Add': lambda l: l[0] + l[1],
'Div': lambda l: l[0] // l[1], # integer div in sympy
'Equal': lambda l: l[0] == l[1],
'Max':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])),
'Mul': lambda l: l[0] * l[1],
'Sub': lambda l: l[0] - l[1],
'Where': lambda l: l[1] if l[0] else l[2]
}
assert node.op_type in funcs
self._compute_on_sympy_data(node, funcs[node.op_type])
def _infer_Cast(self, node):
self._pass_on_sympy_data(node)
def _infer_CategoryMapper(self, node):
input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
if input_type == onnx.TensorProto.STRING:
output_type = onnx.TensorProto.INT64
else:
output_type = onnx.TensorProto.STRING
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output
compress_len = self._new_symbolic_dim_from_output(node)
axis = get_attribute(node, 'axis')
if axis == None:
# when axis is not specified, input is flattened before compress so output is 1D
output_shape = [compress_len]
else:
output_shape = input_shape
output_shape[handle_negative_axis(axis, len(
input_shape))] = compress_len
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Concat(self, node):
if any([i in self.sympy_data_ for i in node.input]):
values = self._get_int_values(node)
if all([v is not None for v in values]):
assert 0 == get_attribute(node, 'axis')
self.sympy_data_[node.output[0]] = []
for i in range(len(node.input)):
value = values[i]
if type(value) == list:
self.sympy_data_[node.output[0]].extend(value)
else:
self.sympy_data_[node.output[0]].append(value)
sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis'), len(sympy_shape))
for i_idx in range(1, len(node.input)):
input_shape = self._get_sympy_shape(node, i_idx)
if input_shape:
sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
self._update_computed_dims(sympy_shape)
# merge symbolic dims for non-concat axes
for d in range(len(sympy_shape)):
if d == axis:
continue
dims = [
self._get_shape(node, i_idx)[d]
for i_idx in range(len(node.input))
if self._get_shape(node, i_idx)
]
if all([d == dims[0] for d in dims]):
continue
merged = self._merge_symbols(dims)
if type(merged) == str:
sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
else:
sympy_shape[d] = merged
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(sympy_shape)))
def _infer_Conv(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_ConstantOfShape(self, node):
sympy_shape = self._get_int_values(node)[0]
vi = self.known_vi_[node.output[0]]
if sympy_shape is not None:
if type(sympy_shape) != list:
sympy_shape = [sympy_shape]
self._update_computed_dims(sympy_shape)
# update sympy data if output type is int, and shape is known
if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(
[is_literal(x) for x in sympy_shape]):
self.sympy_data_[node.output[0]] = np.ones(
[int(x) for x in sympy_shape],
dtype=np.int64) * numpy_helper.to_array(
get_attribute(node, 'value', 0))
else:
# create new dynamic shape
sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(sympy_shape)))
def _infer_Expand(self, node):
expand_to_shape = self._try_get_value(node, 1)
if expand_to_shape is not None:
# new_shape's dim can come from shape value
self._update_computed_dims(expand_to_shape)
shape = self._get_shape(node, 0)
new_shape = self._broadcast_shapes(
shape, get_shape_from_sympy_shape(expand_to_shape))
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_Gather(self, node):
data_shape = self._get_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(data_shape))
indices_shape = self._get_shape(node, 1)
#if indices_shape == []:
# value = self._get_initializer_value(node, 1)
# if isinstance(value.tolist(), int):
# indices_shape = [1]
new_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
#print(new_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, new_shape))
if node.input[0] in self.sympy_data_:
assert 0 == get_attribute(node, 'axis',
0) # only handle 1D sympy compute
idx = self._get_value(node, 1)
data = self.sympy_data_[node.input[0]]
if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[0]] = [
data[int(i)] for i in idx
]
else:
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
assert idx == 0
self.sympy_data_[node.output[0]] = data
def _infer_GatherElements(self, node):
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, indices_shape))
def _infer_GatherND(self, node):
data_shape = self._get_shape(node, 0)
data_rank = len(data_shape)
indices_shape = self._get_shape(node, 1)
indices_rank = len(indices_shape)
last_index_dimension = indices_shape[-1]
assert is_literal(
last_index_dimension) and last_index_dimension <= data_rank
new_shape = indices_shape[:-1] + data_shape[last_index_dimension:]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, new_shape))
def _infer_If(self, node):
# special case for constant condition, in case there are mismatching shape from the non-executed branch
subgraphs = [
get_attribute(node, 'then_branch'),
get_attribute(node, 'else_branch')
]
cond = self._try_get_value(node, 0)
if cond is not None:
if cond > 0:
subgraphs[1].CopyFrom(subgraphs[0])
else:
subgraphs[0].CopyFrom(subgraphs[1])
for i_sub, subgraph in enumerate(subgraphs):
subgraph_infer = self._onnx_infer_subgraph(
node, subgraph, use_node_input=False)
for i_out in range(len(node.output)):
vi = self.known_vi_[node.output[i_out]]
if i_sub == 0:
vi.CopyFrom(subgraph.output[i_out])
vi.name = node.output[i_out]
else:
assert all([
d1 == d2
for d1, d2 in zip(vi.type.tensor_type.shape.dim,
subgraph.output[
i_out].type.tensor_type.shape.dim)
])
# pass on sympy data from subgraph, if cond is constant
if cond is not None and i_sub == (0 if cond > 0 else 1):
if subgraph.output[
i_out].name in subgraph_infer.sympy_data_:
self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[
subgraph.output[i_out].name]
def _infer_Loop(self, node):
subgraph = get_attribute(node, 'body')
assert len(subgraph.input) == len(node.input)
for i, si in enumerate(subgraph.input):
subgraph_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph)
# create a new symbolic dimension for iteration dependent dimension
loop_iter_dim = self._new_symbolic_dim_from_output(node)
num_loop_carried = len(node.input) - 2
for i in range(len(node.output)):
vi = self.known_vi_[node.output[i]]
vi.CopyFrom(
subgraph.output[i + 1]
) # first subgraph output is condition, not in node output
if i >= num_loop_carried:
subgraph_vi_dim = subgraph.output[i +
1].type.tensor_type.shape.dim
vi.type.tensor_type.shape.ClearField('dim')
vi_dim = vi.type.tensor_type.shape.dim
vi_dim.add().dim_param = loop_iter_dim
vi_dim.extend(list(subgraph_vi_dim))
vi.name = node.output[i]
def _infer_MatMul(self, node):
self._compute_matmul_shape(node)
def _infer_MatMulInteger(self, node):
self._compute_matmul_shape(node, onnx.TensorProto.INT32)
def _infer_NonMaxSuppression(self, node):
selected = self._new_symbolic_dim_from_output(node)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], onnx.TensorProto.INT64, [selected, 3]))
def _infer_NonZero(self, node):
input_rank = self._get_shape_rank(node, 0)
# create a new symbolic dimension for NonZero output
nz_len = self._new_symbolic_dim_from_output(node, 0, 1)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
def _infer_OneHot(self, node):
shape = self._get_shape(node, 0)
axis = get_attribute(node, 'axis', -1)
axis = handle_negative_axis(axis, len(shape) + 1)
new_shape = shape[:axis] + [self._new_symbolic_dim_from_output(node)
] + shape[axis:]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[2]].type.tensor_type.elem_type, new_shape))
def _infer_Pad(self, node):
if get_opset(self.out_mp_) <= 10:
pads = get_attribute(node, 'pads')
else:
pads = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
output_shape = get_shape_from_type_proto(vi.type)
if len(output_shape) == 0 or None in output_shape:
sympy_shape = self._get_sympy_shape(node, 0)
rank = len(sympy_shape)
if pads is not None:
assert len(pads) == 2 * rank
new_sympy_shape = [
d + pad_up + pad_down
for d, pad_up, pad_down in zip(sympy_shape, pads[:rank],
pads[rank:])
]
self._update_computed_dims(new_sympy_shape)
else:
# dynamic pads, create new symbolic dimensions
new_sympy_shape = self._new_symbolic_shape(rank, node)
output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_Pool(self, node):
sympy_shape = self._compute_conv_pool_shape(node)
self._update_computed_dims(sympy_shape)
for o in node.output:
if not o:
continue
vi = self.known_vi_[o]
vi.CopyFrom(
helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
sympy_shape)))
def _infer_Range(self, node):
vi = self.known_vi_[node.output[0]]
input_data = self._get_int_values(node)
if all([i is not None for i in input_data]):
start = as_scalar(input_data[0])
limit = as_scalar(input_data[1])
delta = as_scalar(input_data[2])
new_sympy_shape = [
sympy.Max(sympy.ceiling((limit - start) / delta), 0)
]
else:
new_dim = self._new_symbolic_dim_from_output(node)
new_sympy_shape = [self.symbolic_dims_[new_dim]]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[0]].type.tensor_type.
elem_type, get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_ReduceProd(self, node):
axes = get_attribute(node, 'axes')
keep_dims = get_attribute(node, 'keepdims')
if keep_dims == 0 and axes == [0]:
data = self._get_int_values(node)[0]
if data is not None:
self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
def _infer_Reshape(self, node):
shape_value = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if shape_value is None:
shape_shape = self._get_shape(node, 1)
assert len(shape_shape) == 1
shape_rank = shape_shape[0]
assert is_literal(shape_rank)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(
self._new_symbolic_shape(shape_rank, node))))
else:
input_shape = self._get_shape(node, 0)
input_sympy_shape = self._get_sympy_shape(node, 0)
total = int(1)
for d in input_sympy_shape:
total = total * d
new_sympy_shape = []
deferred_dim_idx = -1
non_deferred_size = int(1)
for i, d in enumerate(shape_value):
if type(d) == sympy.Symbol:
new_sympy_shape.append(d)
elif d == 0:
new_sympy_shape.append(input_sympy_shape[i])
non_deferred_size = non_deferred_size * input_sympy_shape[i]
else:
new_sympy_shape.append(d)
if d == -1:
deferred_dim_idx = i
elif d != 0:
non_deferred_size = non_deferred_size * d
assert new_sympy_shape.count(-1) < 2
if -1 in new_sympy_shape:
new_dim = total // non_deferred_size
new_sympy_shape[deferred_dim_idx] = new_dim
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
self._pass_on_sympy_data(node)
def _infer_Resize(self, node):
vi = self.known_vi_[node.output[0]]
input_sympy_shape = self._get_sympy_shape(node, 0)
if get_opset(self.out_mp_) <= 10:
scales = self._try_get_value(node, 1)
if scales is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(d * s))
for d, s in zip(input_sympy_shape, scales)
]
self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
else:
roi = self._try_get_value(node, 1)
scales = self._try_get_value(node, 2)
sizes = self._try_get_value(node, 3)
if sizes is not None:
new_sympy_shape = [
sympy.simplify(sympy.floor(s)) for s in sizes
]
self._update_computed_dims(new_sympy_shape)
elif roi is not None and scales is not None:
rank = len(scales)
assert len(roi) == 2 * rank
roi_start = list(roi)[:rank]
roi_end = list(roi)[rank:]
scales = list(scales)
new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale))
for d, start, end, scale in zip(input_sympy_shape,
roi_start, roi_end, scales)
]
self._update_computed_dims(new_sympy_shape)
else:
new_sympy_shape = self._new_symbolic_shape(
self._get_shape_rank(node, 0), node)
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
new_sympy_shape)))
def _infer_Scan(self, node):
subgraph = get_attribute(node, 'body')
num_scan_inputs = get_attribute(node, 'num_scan_inputs')
scan_input_axes = get_attribute(node, 'scan_input_axes',
[0] * num_scan_inputs)
num_scan_states = len(node.input) - num_scan_inputs
scan_input_axes = [
handle_negative_axis(
ax, self._get_shape_rank(node, i + num_scan_states))
for i, ax in enumerate(scan_input_axes)
]
# We may have cases where the subgraph has optionial inputs that appear in both subgraph's input and initializer,
# but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs.
assert len(subgraph.input) >= len(node.input)
subgraph_inputs = subgraph.input[:len(node.input)]
for i, si in enumerate(subgraph_inputs):
subgraph_name = si.name
si.CopyFrom(self.known_vi_[node.input[i]])
if i >= num_scan_states:
scan_input_dim = si.type.tensor_type.shape.dim
scan_input_dim.remove(scan_input_dim[scan_input_axes[
i - num_scan_states]])
si.name = subgraph_name
self._onnx_infer_subgraph(node, subgraph)
num_scan_outputs = len(node.output) - num_scan_states
scan_output_axes = get_attribute(node, 'scan_output_axes',
[0] * num_scan_outputs)
scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[
-1]].type)[scan_input_axes[-1]]
for i, o in enumerate(node.output):
vi = self.known_vi_[o]
if i >= num_scan_states:
shape = get_shape_from_type_proto(subgraph.output[i].type)
new_dim = handle_negative_axis(
scan_output_axes[i - num_scan_states], len(shape) + 1)
shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:]
vi.CopyFrom(
helper.make_tensor_value_info(o, subgraph.output[
i].type.tensor_type.elem_type, shape))
else:
vi.CopyFrom(subgraph.output[i])
vi.name = o
def _infer_ScatterElements(self, node):
data_shape = self._get_shape(node, 0)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, data_shape))
def _infer_Shape(self, node):
self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
def _infer_Size(self, node):
sympy_shape = self._get_sympy_shape(node, 0)
self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
self.known_vi_[node.output[0]].CopyFrom(
helper.make_tensor_value_info(node.output[0],
onnx.TensorProto.INT64, []))
def _infer_Slice(self, node):
if get_opset(self.out_mp_) <= 9:
axes = get_attribute(node, 'axes')
starts = get_attribute(node, 'starts')
ends = get_attribute(node, 'ends')
steps = [1] * len(axes)
else:
starts = as_list(self._try_get_value(node, 1), keep_none=True)
ends = as_list(self._try_get_value(node, 2), keep_none=True)
axes = self._try_get_value(node, 3)
steps = self._try_get_value(node, 4)
if axes is None and not (starts is None and ends is None):
axes = list(
range(0, len(starts if starts is not None else ends)))
if steps is None and not (starts is None and ends is None):
steps = [1] * len(starts if starts is not None else ends)
axes = as_list(axes, keep_none=True)
steps = as_list(steps, keep_none=True)
new_sympy_shape = self._get_sympy_shape(node, 0)
if starts is None or ends is None:
if axes is None:
for i in range(len(new_sympy_shape)):
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i in axes:
new_sympy_shape[i] = self._new_symbolic_dim_from_output(
node, 0, i)
else:
for i, s, e, t in zip(axes, starts, ends, steps):
idx = handle_negative_axis(i, len(new_sympy_shape))
if is_literal(e):
if e >= self.int_max_:
e = new_sympy_shape[i]
elif e <= -self.int_max_:
e = 0 if s > 0 else -1
elif is_literal(new_sympy_shape[i]):
if e < 0:
e = e + new_sympy_shape[i]
e = min(e, new_sympy_shape[i])
else:
if e > 0:
e = sympy.Min(
e, new_sympy_shape[i]
) if e > 1 else e #special case for slicing first to make computation easier
else:
e = new_sympy_shape[i] + e
else:
if is_literal(new_sympy_shape[i]):
e = sympy.Min(e, new_sympy_shape[i])
else:
try:
if e >= new_sympy_shape[i]:
e = new_sympy_shape[i]
except Exception:
print(
'Unable to determine if {} <= {}, treat as equal'
.format(e, new_sympy_shape[i]))
e = new_sympy_shape[i]
if is_literal(s) and int(s) < 0:
s = new_sympy_shape[i] + s
new_sympy_shape[idx] = (e - s + t + (-1 if t > 0 else 1)) // t
self._update_computed_dims(new_sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
# handle sympy_data if needed, for slice in shape computation
if node.input[0] in self.sympy_data_:
assert [0] == axes
assert len(starts) == 1
assert len(ends) == 1
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][
starts[0]:ends[0]]
def _infer_Split(self, node):
input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(input_sympy_shape))
split = get_attribute(node, 'split')
if not split:
num_outputs = len(node.output)
split = [input_sympy_shape[axis] /
sympy.Integer(num_outputs)] * num_outputs
self._update_computed_dims(split)
else:
split = [sympy.Integer(s) for s in split]
for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[i_o], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [
split[i_o]
] + input_sympy_shape[axis + 1:])))
self.known_vi_[vi.name] = vi
def _infer_Squeeze(self, node):
self._pass_on_sympy_data(node)
def _infer_Tile(self, node):
repeats_value = self._get_value(node, 1)
input_sympy_shape = self._get_sympy_shape(node, 0)
new_sympy_shape = []
for i, d in enumerate(input_sympy_shape):
new_dim = d * repeats_value[i]
new_sympy_shape.append(new_dim)
self._update_computed_dims(new_sympy_shape)
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_sympy_shape)))
def _infer_TopK(self, node):
rank = self._get_shape_rank(node, 0)
axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank)
new_shape = self._get_shape(node, 0)
if get_opset(self.out_mp_) <= 9:
k = get_attribute(node, 'k')
else:
k = self._get_int_values(node)[1]
if k == None:
k = self._new_symbolic_dim_from_output(node)
else:
k = as_scalar(k)
if type(k) in [int, str]:
new_shape[axis] = k
else:
new_sympy_shape = self._get_sympy_shape(node, 0)
new_sympy_shape[axis] = k
self._update_computed_dims(
new_sympy_shape
) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
new_shape = get_shape_from_sympy_shape(new_sympy_shape)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
i_o], vi.type.tensor_type.elem_type, new_shape))
def _infer_Unsqueeze(self, node):
self._pass_on_sympy_data(node)
def _infer_ZipMap(self, node):
map_key_type = None
if get_attribute(node, 'classlabels_int64s') is not None:
map_key_type = onnx.TensorProto.INT64
elif get_attribute(node, 'classlabels_strings') is not None:
map_key_type = onnx.TensorProto.STRING
assert map_key_type is not None
new_vi = onnx.ValueInfoProto()
new_vi.name = node.output[0]
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(new_vi)
def _infer_impl(self, in_mp, start_sympy_data={}):
self.sympy_data_ = start_sympy_data
self.out_mp_.graph.ClearField('value_info')
self._apply_suggested_merge(graph_input_only=True)
self.input_symbols_ = set()
for i in self.out_mp_.graph.input:
input_dims = i.type.tensor_type.shape.dim
for i_dim in range(len(input_dims)):
if get_dim_from_type_proto(input_dims[i_dim]) is None:
# some models use None for symbolic dim in input, replace it with a string
input_dims[i_dim].dim_param = self._new_symbolic_dim(i.name,
i_dim)
self.input_symbols_.update([
d for d in get_shape_from_type_proto(i.type) if type(d) == str
])
for s in self.input_symbols_:
if s in self.suggested_merge_:
s_merge = self.suggested_merge_[s]
assert s_merge in self.symbolic_dims_
self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
else:
self.symbolic_dims_[s] = sympy.Symbol(s, integer=True)
# create a temporary ModelProto for single node inference
# note that we remove initializer to have faster inference
# for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways
self.tmp_mp_ = onnx.ModelProto()
self.tmp_mp_.CopyFrom(self.out_mp_)
self.tmp_mp_.graph.ClearField('initializer')
for node in self.out_mp_.graph.node:
assert all([i in self.known_vi_ for i in node.input if i])
self._onnx_infer_single_node(node)
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
if self.verbose_ > 2:
print(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input):
print(' Input {}: {} {}€5€5€5€5€5'.format(
i, name, 'initializer'
if name in self.initializers_ else ''))
# onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
# symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger',
'MatMulInteger16', 'Where', 'Sum'
]:
vi = self.known_vi_[node.output[0]]
out_rank = len(get_shape_from_type_proto(vi.type))
in_shapes = [
self._get_shape(node, i) for i in range(len(node.input))
]
for d in range(out_rank - (2 if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
] else 0)):
in_dims = [
s[len(s) - out_rank + d] for s in in_shapes
if len(s) + d >= out_rank
]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
out_type_kind = out_type.WhichOneof('value')
# only TensorProto and SparseTensorProto have shape
if out_type_kind != 'tensor_type' and out_type_kind != 'sparse_tensor_type':
continue
out_shape = get_shape_from_type_proto(vi.type)
out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
if self.verbose_ > 2:
print(' {}: {} {}'.format(node.output[
i_o], str(out_shape), vi.type.tensor_type.elem_type))
if node.output[i_o] in self.sympy_data_:
print(' Sympy Data: ' + str(self.sympy_data_[
node.output[i_o]]))
if None in out_shape or out_type_undefined:
if self.auto_merge_:
if node.op_type in [
'Add', 'Sub', 'Mul', 'Div', 'MatMul',
'MatMulInteger', 'MatMulInteger16', 'Concat',
'Where', 'Sum'
]:
shapes = [
self._get_shape(node, i)
for i in range(len(node.input))
]
if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
]:
# only support auto merge for MatMul for dim < rank-2 when rank > 2
assert len(shapes[0]) > 2 and dim_idx[0] < len(
shapes[0]) - 2
assert len(shapes[1]) > 2 and dim_idx[1] < len(
shapes[1]) - 2
elif node.op_type == 'Expand':
# auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
shapes = [
self._get_shape(node, 0),
self._get_value(node, 1)
]
else:
shapes = []
if shapes:
for idx in range(len(out_shape)):
if out_shape[idx] is not None:
continue
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
assert all([d >= 0 for d in dim_idx])
self._add_suggested_merge([
s[i] if is_literal(s[i]) else str(s[i])
for s, i in zip(shapes, dim_idx)
])
self.run_ = True
else:
self.run_ = False
else:
self.run_ = False
# create new dynamic dims for ops not handled by symbolic shape inference
if self.run_ == False and not node.op_type in self.dispatcher_:
is_unknown_op = (out_type_undefined and
len(out_shape) == 0)
if is_unknown_op:
# unknown op to ONNX, maybe from higher opset or other domain
# only guess the output rank from input 0 when using guess_output_rank option
out_rank = self._get_shape_rank(
node, 0) if self.guess_output_rank_ else -1
else:
# valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
out_rank = len(out_shape)
if out_rank >= 0:
new_shape = self._new_symbolic_shape(out_rank, node,
i_o)
vi.CopyFrom(
helper.make_tensor_value_info(
vi.name, self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(new_shape)))
if self.verbose_ > 0:
if is_unknown_op:
print(
"Possible unknown op: {} node: {}, guessing {} shape"
.format(node.op_type, node.name,
vi.name))
if self.verbose_ > 2:
print(' {}: {} {}'.format(
node.output[i_o],
str(new_shape),
vi.type.tensor_type.elem_type))
self.run_ = True
continue # continue the inference after guess, no need to stop as no merge is needed
if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
print('Stopping at incomplete shape inference at ' +
node.op_type + ': ' + node.name)
print('node inputs:')
for i in node.input:
print(self.known_vi_[i])
print('node outputs:')
for o in node.output:
print(self.known_vi_[o])
if self.auto_merge_ and not out_type_undefined:
print('Merging: ' + str(self.suggested_merge_))
return False
self.run_ = False
return True
def _update_output_from_vi(self):
for output in self.out_mp_.graph.output:
if output.name in self.known_vi_:
tmp_output = self.known_vi_[output.name]
output.CopyFrom(tmp_output)
@staticmethod
def infer_shapes(in_mp,
int_max=2**31 - 1,
fixed_input_shape=None,
auto_merge=True,
guess_output_rank=False,
verbose=0):
if get_opset(in_mp) < 7:
print('Only support shape inferencing models of opset 7 and above.')
return
symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(
in_mp, input_shapes=fixed_input_shape)
try:
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
in_mp)
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except:
print('Stopping at incomplete shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict as _dict
import numpy as _np
default_op_mapping_field_values = _dict()
default_op_mapping_field_values['FLUID_OP'] = ''
default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None
default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None
default_op_mapping_field_values['ATTR_MAPPING'] = dict()
default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'], dict(), dict(
min=(_np.asarray(
[255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)[0]),
max=(_np.asarray(
[255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)[0]), )
],
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceSum': [
'reduce_sum', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMin': [
'reduce_min', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMax': [
'reduce_max', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
#active function
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'),
dict(alpha=1.)
],
'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], dict(
alpha='slope', beta='offset'), dict(
slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']],
}
default_ioa_constraint = {
'Gather': [(lambda i, o, a: a.get('axis', 0) == 0,
'only axis = 0 is supported')],
}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -12,101 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.graph import GraphNode
from x2paddle.op_mapper.onnx_opsets.opset9 import OpSet9
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.op_mapper.onnx_opsets.custom_layer import *
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
from x2paddle.op_mapper.onnx_directly_map import default_op_mapping_field_values
from x2paddle.op_mapper.onnx_directly_map import default_op_mapping
from x2paddle.op_mapper.onnx_directly_map import default_ioa_constraint
from x2paddle.op_mapper.onnx_custom_layer import *
from x2paddle.core.util import string
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
import logging as _logging
from collections import OrderedDict as _dict
import math
import os
import shutil
from functools import reduce
import onnxruntime as rt
_logger = _logging.getLogger(__name__)
def _const_weight_or_none(node):
if 'Constant' in node.layer_type:
return node.value
if isinstance(node, ONNXGraphDataNode):
return node.weight
return None
def get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
pad0 = int(pad_size / 2)
pad1 = pad_size - pad0
return [pad0, pad1]
class ONNXOpMapper(OpMapper):
elementwise_ops = {
'Add': 'elementwise_add',
'Div': 'elementwise_div',
'Sub': 'elementwise_sub',
'Mul': 'elementwise_mul',
'Pow': 'elementwise_pow',
}
def __init__(self, decoder, save_dir):
def __init__(self, decoder):
super(ONNXOpMapper, self).__init__()
self.decoder = decoder
self.graph = decoder.onnx_graph
self.input_shapes = []
self.weights = dict()
self.omit_nodes = list()
self.used_custom_layers = dict()
self.is_inference = False
self.tmp_data_dir = os.path.join(save_dir, 'tmp_data')
self.tmp_outputs_dict = {}
self.get_output_shapes()
self.support_op_sets = [9, ]
self.default_op_set = 9
self.graph = decoder.graph
self.opset = self.create_opset(decoder)
if not self.op_checker():
raise Exception("Model are not supported yet.")
#mapping op
print("Total nodes: {}".format(
sum([
isinstance(node, ONNXGraphNode)
for name, node in self.graph.node_map.items()
])))
print("Nodes converting ...")
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self, op):
func = getattr(self, op)
if hasattr(self.opset, op):
func = getattr(self.opset, op)
func(node)
elif op in default_op_mapping:
self.directly_map(node)
elif op in self.opset.default_op_mapping:
self.opset.directly_map(node)
elif op in custom_layers:
self.deal_custom_layer(node)
elif op in self.elementwise_ops:
self.elementwise_map(node)
self.remove_tmp_data()
self.opset.deal_custom_layer(node)
elif op in self.opset.elementwise_ops:
self.opset.elementwise_map(node)
print("Nodes converted.")
self.weights = self.opset.weights
self.omit_nodes = self.opset.omit_nodes
self.used_custom_layers = self.opset.used_custom_layers
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self, op) and \
op not in default_op_mapping and \
if not hasattr(self.opset, op) and \
op not in self.opset.default_op_mapping and \
op not in custom_layers and \
op not in self.elementwise_ops:
op not in self.opset.elementwise_ops:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
......@@ -117,1346 +71,22 @@ class ONNXOpMapper(OpMapper):
print(op)
return False
def get_results_of_inference(self, model, value_infos, data_nodes):
if not os.path.exists(self.tmp_data_dir):
os.makedirs(self.tmp_data_dir)
inputs_dict = {}
for data_node in data_nodes:
value_info = value_infos[data_node]
shape = value_info['shape']
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
ipt = np.random.random(shape).astype(value_info['dtype'])
inputs_dict[data_node] = ipt
model = onnx.shape_inference.infer_shapes(model)
outputs = []
for value_info in model.graph.value_info:
outputs.append(value_info.name)
model.graph.ClearField('output')
model.graph.output.MergeFrom(model.graph.value_info)
onnx.save(model,
os.path.join(self.tmp_data_dir, 'onnx_model_infer.onnx'))
sess = rt.InferenceSession(
os.path.join(self.tmp_data_dir, 'onnx_model_infer.onnx'))
res = sess.run(None, input_feed=inputs_dict)
self.tmp_outputs_dict = dict(zip(outputs, res))
return
def get_dynamic_shape(self, layer):
"""
get dynamic shape from infer_result
"""
if layer not in self.tmp_outputs_dict:
return [None, None, None]
output = self.tmp_outputs_dict[layer]
return output.tolist(), output.dtype, output.shape
def get_output_shapes(self):
"""
build topo_sort of ONNX model
"""
nodes = self.decoder.model.graph.node
node_map = self.decoder.onnx_graph.node_map
value_infos = self.decoder.onnx_graph.value_infos
onnx_model = self.decoder.model
for layer in nodes:
node = node_map[layer.name]
for opt in layer.output:
if opt in value_infos:
value_info = value_infos[opt]
if len(value_info['shape']) == 0 or value_info[
'dtype'] is None or 0 in value_info['shape']:
if self.is_inference == False:
self.get_results_of_inference(
onnx_model, value_infos,
self.decoder.onnx_graph.place_holder_nodes)
self.is_inference = True
_, dtype, shape = self.get_dynamic_shape(opt)
node.out_shapes.append(shape)
node.dtype = dtype
else:
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else:
if self.is_inference == False:
self.get_results_of_inference(
onnx_model, value_infos,
self.decoder.onnx_graph.place_holder_nodes)
self.is_inference = True
_, dtype, shape = self.get_dynamic_shape(opt)
node.dtype = dtype
node.out_shapes.append(shape)
def remove_tmp_data(self):
"""
remove temporarily generated file
"""
if os.path.exists(self.tmp_data_dir):
import shutil
shutil.rmtree(self.tmp_data_dir)
def directly_map(self, node, name='', *args, **kwargs):
inputs = node.layer.input
outputs = node.layer.output
op_type = node.layer_type
attrs = node.attr_map
info = default_op_mapping[op_type]
info.extend(list(default_op_mapping_field_values.values())[len(info):])
(
fluid_op,
fluid_input_args,
fluid_output_args,
attr_mapping,
default_attrs,
input_perm,
output_perm,
fill_name_field, ) = info
if fluid_op in default_ioa_constraint:
for predicate, message in default_ioa_constraint[fluid_op]:
assert predicate(inputs, outputs, attrs), message
mapped_attrs = {
attr_mapping.get(key, key): value
for key, value in attrs.items()
}
if '' in mapped_attrs:
mapped_attrs.pop('')
if '_' in mapped_attrs:
mapped_attrs.pop('_')
fluid_attrs = default_attrs.copy()
fluid_attrs.update(mapped_attrs)
inputs = inputs if input_perm is None else list(
map(lambda i: inputs[i], input_perm))
val_inps = []
for idx, ipt in enumerate(inputs):
val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
val_outs = outputs if output_perm is None else list(
map(lambda i: outputs[i], output_perm))
attr = fluid_attrs
assert len(val_inps) == 1, 'directly_map error with multi inputs'
if fluid_op not in ['shape']:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_inps[0], output=val_outs[0], param_attr=attr)
def deal_custom_layer(self, node):
op = node.layer_type
custom_code, func = make_custom_layer(node)
child_func_code, child_func = make_custom_child_func(node)
params = get_params(node.layer, node.layer_type)
arg_names, kwargs = set_args(func, params)
kwargs['name'] = string(node.layer_name)
node.fluid_code.add_layer(
func.__code__.co_name,
inputs=node.inputs,
output=node,
param_attr=kwargs,
is_custom_layer=True)
if op not in self.used_custom_layers:
self.used_custom_layers[op] = custom_code
if op + '_child_func' not in self.used_custom_layers:
if child_func_code is not None:
self.used_custom_layers[op +
'_child_func'] = child_func_code
def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) < len(val_y_shape):
val_x, val_y = val_y, val_x
str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0
if str_y_shape not in str_x_shape:
for dim in val_y_shape:
if dim == 1:
slice_idx += 1
def create_opset(self, decoder):
run_op_set = self.default_op_set
opset = ''
if decoder.op_set in self.support_op_sets:
opset = 'OpSet' + str(decoder.op_set)
elif decoder.op_set < self.default_op_set:
opset = 'OpSet' + str(self.default_op_set)
else:
for op_set in self.support_op_sets:
if decoder.op_set > op_set:
run_op_set = op_set
else:
break
attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped'
attr_reshaped = {
'shape': val_y_reshaped,
'name': string(var_y_reshaped)
}
node.fluid_code.add_layer(
'reshape',
inputs=val_y,
output=var_y_reshaped,
param_attr=attr_reshaped)
inputs = {'x': val_x, 'y': var_y_reshaped}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
else:
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
def place_holder(self, node):
self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = {
"dtype": string(node.dtype),
"shape": shape,
"name": string(node.layer_name),
"append_batch_size": 'False'
}
node.fluid_code.add_layer(
"data", inputs=None, output=node, param_attr=attr)
def create_parameter(self, node, parameter=None):
if parameter is not None:
node = parameter
dtype = node.dtype
shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
shape = [1]
self.weights[node.layer_name] = node.weight
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'attr': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE
assert len(pads) & 1 == 0
symmetric = True
ndims = len(pads) // 2
for idx_dim in range(ndims):
if pads[idx_dim] != pads[ndims + idx_dim]:
symmetric = False
break
if symmetric:
return pads[:ndims], val_name
val_padded = self.Pad(node, op_independent=False)
return [0] * ndims, val_padded
def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = val_y.out_shapes[0]
if out_shape is not None:
assert len(out_shape) == 4, 'only 4-D Tensor as X and Y supported'
out_shape = out_shape[2:]
scales = _const_weight_or_none(val_scales)
if isinstance(val_scales, ONNXGraphNode):
scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)
attr = {'name': string(node.layer_name)}
use_scales = True
if scales is not None:
try:
assert len(scales) == 4, 'only 4-D Tensor as X and Y supported'
assert scales[0] == 1 and scales[
1] == 1, 'only scale on (NC)HW supported'
assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported'
except:
use_scales = False
scale = scales[2] if scales else None
if scale is None:
assert out_shape, 'neither scales nor output shape is available'
else:
if out_shape is None:
in_shape = val_x.out_shapes[0]
assert in_shape is not None, 'out_shape required but not inferrable'
assert len(
in_shape) == 4, 'only 4-D Tensor as X and Y supported'
out_shape = [in_shape[2] * scale, in_shape[3] * scale]
mode = node.get_attr('mode', 'nearest')
fluid_op = 'resize_{}'.format(mode)
if 'linear' in mode:
print(
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
)
fluid_op = 'resize_bilinear'
if use_scales and scale is not None:
attr['scale'] = scale
else:
attr['out_shape'] = out_shape
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def RoiAlign(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
pooled_height = node.get_attr('output_height')
pooled_width = node.get_attr('output_width')
spatial_scale = node.get_attr('spatial_scale')
sampling_ratio = node.get_attr('sampling_ratio')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio,
}
node.fluid_code.add_layer(
'roi_align',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
def MaxRoiPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
spatial_scale = node.get_attr('spatial_scale')
pooled_height, pooled_width = node.get_attr('pooled_shape')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
}
node.fluid_code.add_layer(
'roi_pool',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
def Pad(self, node, op_independent=True):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
pads = node.get_attr('pads')
mode = node.get_attr('mode', 'constant')
value = node.get_attr('value', 0.)
data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0]
assume_pad2d = False
attr = {}
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
if data_shape:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
if assume_pad2d:
fluid_op = 'pad2d'
attr['data_format'] = string('NCHW')
attr['mode'] = string(mode)
else:
attr = {'pad_value': value}
fluid_op = 'pad'
if len(pads) == 4:
paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
elif len(pads) == 8:
paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0:
fluid_op = 'pad2d'
paddings = paddings[4:]
attr['mode'] = string(mode)
attr['paddings'] = paddings
if op_independent:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
else:
attr['name'] = string(node.layer_name + '_paded')
node.fluid_code.add_layer(
fluid_op,
inputs=val_x,
output=node.layer_name + '_paded',
param_attr=attr)
return node.layer_name + '_paded'
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
if len(val_x.out_shapes[0]) == 0:
node.fluid_code.add_layer(
'assign', inputs=val_x, output=node, param_attr=None)
else:
attr = {'axes': axes, 'name': string(node.layer_name)}
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
def Shrink(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
bias = node.get_attr('bias')
lambd = node.get_attr('lambd')
assert bias == 0.0, 'not support bias!=0'
attr = {'threshold': lambd, 'name': node.layer_name}
node.fluid_code.add_layer(
'hard_shrink', inputs=val_x, output=node, param_attr=attr)
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
'greater_than',
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
def Constant(self, node):
val_output = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
dtype = np.dtype(value.dtype)
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
shape = node.get_attr('shape', None)
if shape is None:
shape = val_output.out_shapes[0]
if shape is None:
shape = list(value.shape)
_logger.warning('in (Constant -> %s): '
'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails',
val_output.layer_name, val_output.layer_name)
if len(value) == 1:
value = value.tolist()
shape = [1]
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
else:
value = np.reshape(value, shape)
self.weights[node.layer_name] = value
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'attr': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
def Resize(self, node):
self._interpolate(node)
def Upsample(self, node):
self._interpolate(node)
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer(
'elementwise_mul',
inputs=inputs,
output=node.layer_name,
param_attr=attr)
def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0]
axis = node.get_attr('axis', 0)
assert len(
indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
'index': indices},
output=node,
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
elif len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
node.fluid_code.add_layer(
'reshape',
inputs=indices,
output=indices,
param_attr={'shape': [reshape_shape, ]})
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
'index': indices},
output=node,
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None
if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True)
if len(node.inputs) > 3:
axes = self.graph.get_input_node(node, idx=3, copy=True)
self.omit_nodes.append(axes.layer_name)
axes = _const_weight_or_none(axes)
if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(steps.layer_name)
steps = _const_weight_or_none(steps)
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
starts = _const_weight_or_none(starts).copy()
ends = _const_weight_or_none(ends).copy()
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
axes = node.get_attr('axes')
val_y = self.graph.get_node(node.layer.output[0], copy=True)
shape = val_x.out_shapes[0]
if shape is not None:
for idx, value in enumerate(starts):
if value > shape[axes[idx]]:
starts[idx] = shape[axes[idx]]
for idx, value in enumerate(ends):
if value > shape[axes[idx]]:
ends[idx] = shape[axes[idx]]
attr = {"axes": axes, "starts": starts, "ends": ends}
node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr)
def ConstantOfShape(self, node):
val_shape = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
shape = _const_weight_or_none(val_shape)
if shape is None:
shape = node.out_shapes[0]
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
'this is not supported')
value = node.get_attr('value')
dtype = value.dtype
value = value.tolist()
if len(value) == 1:
shape = [1]
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'split'
split = node.get_attr('split')
axis = node.get_attr('axis', 0)
attr = {
'num_or_sections': split,
'dim': axis,
'name': string(node.layer_name)
}
node.fluid_code.add_layer(
'split', inputs=val_x, output=val_y, param_attr=attr)
def Reshape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
shape = None
if isinstance(val_shape, ONNXGraphDataNode):
self.omit_nodes.append(val_shape.layer_name)
attr = {'name': string(node.layer_name)}
# catch dynamic graph shape
if isinstance(val_shape, ONNXGraphNode):
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
attr['actual_shape'] = val_shape_cast
else:
attr['actual_shape'] = val_shape
if shape is None:
shape = val_reshaped.out_shapes[0]
if shape is None:
shape = [1, -1]
_logger.warning(
'in %s(%s -> Reshape -> %s): '
'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined', node.layer_name,
val_x.layer_name, val_reshaped.layer_name)
attr['shape'] = shape
node.fluid_code.add_layer(
'reshape', inputs=val_x, output=node, param_attr=attr)
def Cast(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
dtype = node.get_attr('to')
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)}
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr("kernel_shape")
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0))
pads = node.get_attr('pads', [0] * (poolnd * 2))
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0]
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"pool_size": kernel_shape,
"pool_type": string('avg'),
"pool_stride": strides,
"pool_padding": paddings,
"ceil_mode": ceil_mode,
"exclusive": 'True',
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def Concat(self, node):
inputs = []
for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str):
inputs.append(ipt)
else:
inputs.append(ipt.layer_name)
axis = node.get_attr('axis')
attr = {'axis': axis}
node.fluid_code.add_layer(
'concat', inputs=inputs, output=node, param_attr=attr)
def Flatten(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis', 1)
attr = {"axis": str(axis), "name": string(node.layer_name)}
node.fluid_code.add_layer(
'flatten', inputs=val_x, output=node, param_attr=attr)
def Gemm(self, node):
val_a = self.graph.get_input_node(node, idx=0, copy=True)
val_b = self.graph.get_input_node(node, idx=1, copy=True)
val_c = self.graph.get_input_node(node, idx=2, copy=True)
alpha = node.get_attr('alpha', 1.) # optional
beta = node.get_attr('beta', 1.) # optional
trans_a = bool(node.get_attr('transA', 0)) # optional
trans_b = bool(node.get_attr('transB', 0)) # optional
val_mm = node.layer_name + '_mm'
matmul_inputs = {"x": val_a, "y": val_b}
attr_matmul = {
"transpose_x": trans_a,
"transpose_y": trans_b,
"alpha": alpha,
"name": string(val_mm)
}
node.fluid_code.add_layer(
'matmul',
inputs=matmul_inputs,
output=val_mm,
param_attr=attr_matmul)
if beta != 0:
if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
inputs=add_inputs,
output=node,
param_attr=attr)
else:
var_beta = node.layer_name + '_beta'
matmul_beta_inputs = {"x": val_c, "y": var_beta}
node.fluid_code.add_layer(
"Constant",
inputs=matmul_beta_inputs,
output=var_beta,
param_attr={'value': beta})
add_inputs = {"x": val_mm, "y": var_beta}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
inputs=add_inputs,
output=node,
param_attr=attr)
def Sum(self, node):
val_inps = node.layer.input
inputs = {
"x": self.graph.get_input_node(
node, idx=0, copy=True),
"y": self.graph.get_input_node(
node, idx=1, copy=True),
}
node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)
for idx, ipt in enumerate(val_inps[2:]):
y = self.graph.get_input_node(node, idx=idx, copy=True)
inputs = {
"x": node.layer_name,
"y": y,
}
node.fluid_code.add_layer(
"elementwise_add", inputs=inputs, output=node)
def MatMul(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {"x": val_x, "y": val_y}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=attr)
def BatchNormalization(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True)
val_mean = self.graph.get_input_node(node, idx=3, copy=True)
val_var = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(val_scale.layer_name)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_mean.layer_name)
self.omit_nodes.append(val_var.layer_name)
momentum = node.get_attr('momentum', .9)
epsilon = node.get_attr('epsilon', 1e-5)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
attr = {
"momentum": momentum,
"epsilon": epsilon,
"data_layout": string('NCHW'),
"is_test": True,
"param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name),
"moving_mean_name": string(val_mean.layer_name),
"moving_variance_name": string(val_var.layer_name),
"use_global_stats": spatial,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
"batch_norm", inputs=val_x, output=node, param_attr=attr)
def Transpose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
perm = node.get_attr('perm')
attr = {'perm': perm, "name": string(node.layer_name)}
node.fluid_code.add_layer(
"transpose", inputs=val_x, output=node, param_attr=attr)
def Relu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"relu", inputs=val_x, output=node, param_attr=attr)
def PRelu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True)
mode = 'channel'
shape_slope = val_slope.out_shapes[0]
if len(shape_slope) == 1:
mode = 'all'
elif len(shape_slope) > 2:
mode = 'element'
attr = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode)
}
node.fluid_code.add_layer(
"prelu", inputs=val_x, output=node, param_attr=attr)
def Squeeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, "name": string(node.layer_name)}
node.fluid_code.add_layer(
"squeeze", inputs=val_x, output=node, param_attr=attr)
def Equal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
"equal",
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
def Where(self, node):
condition = self.graph.get_input_node(node, idx=0, copy=True)
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.layer_name + '_not'
node.fluid_code.add_layer(
"logical_not",
inputs=condition,
output=not_condition,
param_attr=None)
cast_not_condition = not_condition + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=not_condition,
output=cast_not_condition,
param_attr={'dtype': string(val_x.dtype)})
cast_condition = condition.layer_name + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=condition,
output=cast_condition,
param_attr={'dtype': string(val_x.dtype)})
mul_val_x = val_x.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_x,
'y': cast_condition},
output=mul_val_x,
param_attr=None)
mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_y,
'y': cast_not_condition},
output=mul_val_y,
param_attr=None)
node.fluid_code.add_layer(
"elementwise_add",
inputs={'x': mul_val_x,
'y': mul_val_y},
output=node,
param_attr=None)
def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
where_name = node.layer_name + '_where'
node.fluid_code.add_layer(
"where", inputs=val_x.layer_name + '!=0', output=where_name)
dims = len(val_x.out_shapes[0])
elements_count_val_x = reduce(lambda x, y: x * y, val_x.out_shapes[0])
flatten_names = []
for dim in range(dims):
slice_name = node.layer_name + '_slice' + str(dim)
flatten_name = node.layer_name + '_flatten' + str(dim)
flatten_names.append(flatten_name)
attr = {
'axes': list(range(dims)),
'starts': [0, dim],
'ends': [elements_count_val_x, dim + 1]
}
node.fluid_code.add_layer(
"slice", inputs=where_name, output=slice_name, param_attr=attr)
node.fluid_code.add_layer(
"flatten",
inputs=slice_name,
output=flatten_name,
param_attr={'axis': 0})
node.fluid_code.add_layer(
"concat", inputs=flatten_names, output=node,
param_attr={'axis': 0})
def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer("assign", inputs=val_x, output=node)
def Tile(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
repeats = _const_weight_or_none(val_repeats)
assert repeats is not None, 'for OP:Tile, only const repeats supported'
if isinstance(repeats, int):
repeats = [repeats]
attr = {
'expand_times': repeats,
"name": string(node.layer_name),
}
node.fluid_code.add_layer(
"expand", inputs=val_x, output=node, param_attr=attr)
def MaxPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
assert node.get_attr(
"dilations") is None, 'only dilations = 0 is supported' # optional
kernel_shape = node.get_attr("kernel_shape")
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional
pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0]
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"pool_size": kernel_shape,
"pool_type": string("max"),
"pool_stride": strides,
"pool_padding": paddings,
"ceil_mode": ceil_mode,
"name": string(node.layer_name),
"exclusive": False
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def _global_pool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
input_shape = val_x.out_shapes[0]
output_shape = val_y.out_shapes[0]
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # N
if input_shape:
poolnd = len(input_shape) - 2 # NC...
elif output_shape:
poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
fluid_op = 'pool{}d'.format(poolnd)
pool_type = None
if node.layer.op_type == 'GlobalMaxPool':
pool_type = 'max'
elif node.layer.op_type == 'GlobalAveragePool':
pool_type = 'avg'
attr = {
"pool_type": string(pool_type),
"global_pooling": True,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def GlobalMaxPool(self, node):
self._global_pool(node)
def GlobalAveragePool(self, node):
self._global_pool(node)
def Conv(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
self.omit_nodes.append(val_w.layer_name)
has_bias = len(node.layer.input) == 3
if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr('kernel_shape')
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
num_out_channels = val_w.out_shapes[0][0] # OI...
fluid_op = 'conv{}d'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd) # optional
dilations = node.get_attr('dilations', [1] * convnd) # optional
pads = node.get_attr('pads', [0] * (convnd * 2)) # optional
input_shape = val_x.out_shapes[0]
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"num_filters": num_out_channels,
"filter_size": kernel_shape,
"stride": strides,
"padding": paddings,
"dilation": dilations,
"groups": num_groups,
'param_attr': string(val_w.layer_name),
"name": string(node.layer_name)
}
if has_bias:
attr["bias_attr"] = string(val_b.layer_name)
else:
attr["bias_attr"] = False
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def ConvTranspose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_b = None
if len(node.layer.input) > 2:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_w.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
out_padding = node.get_attr('output_padding', [0, 0])
kernel_shape = node.get_attr('kernel_shape')
assert kernel_shape, 'kernel_shape not inferred'
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_out_channels = val_w.out_shapes[0][1]
fluid_op = 'conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd)
dilations = node.get_attr('dilations', [1] * convnd)
output_size = node.get_attr('output_shape', [])
pads = node.get_attr('pads', [0] * (convnd * 2))
paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)
output_size = [0, 0]
output_size[0] = (val_x.out_shapes[0][2] - 1
) * strides[0] - 2 * paddings[0] + dilations[0] * (
kernel_shape[0] - 1) + 1 + out_padding[0]
output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1]
attr = {
'num_filters': num_out_channels,
'output_size': output_size or None,
'filter_size': kernel_shape,
'padding': paddings,
'stride': strides,
'dilation': dilations,
'groups': num_groups,
'param_attr': string(val_w.layer_name),
'bias_attr': None if val_b is None else string(val_b.layer_name),
'name': string(node.layer_name),
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def GRU(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_r = self.graph.get_input_node(node, idx=2, copy=True)
val_b = None
val_len = None
val_xh = None
miss_arg_num = 0
num_ipt = len(node.layer.input)
if num_ipt > 3 and node.layer.input[3] != '':
val_b = self.graph.get_input_node(node, idx=3, copy=True)
else:
miss_arg_num += 1
if num_ipt > 4 and node.layer.input[4] != '':
val_len = self.graph.get_input_node(
node, idx=4 - miss_arg_num, copy=True)
else:
miss_arg_num += 1
if num_ipt > 5 and node.layer.input[5] != '':
val_xh = self.graph.get_input_node(
node, idx=5 - miss_arg_num, copy=True)
data, dtype, shape = self.get_dynamic_shape(val_x.layer_name)
x_shape = val_x.out_shapes[0]
assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert node.get_attr('clip', None) is None, 'clipping not supported'
hidden_size = node.get_attr('hidden_size', None)
if hidden_size is None:
r_shape = val_r.out_shapes[0]
if r_shape:
hidden_size = r_shape[-1]
if hidden_size is None:
w_shape = var_w.out_shapes[0]
if w_shape:
hidden_size = w_shape[-2] // 3
if hidden_size is None and val_b:
b_shape = val_b.out_shapes[0]
if b_shape:
hidden_size = b_shape[-1] // 6
if hidden_size is None and val_xh:
xh_shape = val_xh.out_shapes[0]
if xh_shape:
hidden_size = xh_shape[-1]
direction = node.get_attr('direction', 'forward')
assert direction != 'bidirectional', 'direction = bidirectional not supported'
activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
assert len(activations) == 2, 'bidirectional operation not supported'
assert node.get_attr('linear_before_reset',
0) == 0, 'only linear_before_reset = 0 supported'
activations = [s.lower() for s in activations]
gate_activation, candidate_activation = activations
is_reverse = direction == 'reverse'
var_x0 = node.layer_name + '_x0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_x,
output=var_x0,
param_attr={'axes': [1],
'name': string(var_x0)})
var_w0 = node.layer_name + '_w0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_w,
output=var_w0,
param_attr={'axes': [0],
'name': string(var_w0)})
var_fc = node.layer_name + '_fc'
var_mm = (node.layer_name + '_mm') if val_b else var_fc
node.fluid_code.add_layer(
'matmul',
inputs={'x': var_x0,
'y': var_w0},
output=var_mm,
param_attr={
'transpose_x': 0,
'transpose_y': 1,
'name': string(var_mm)
})
var_r0 = node.layer_name + '_r0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_r,
output=var_r0,
param_attr={'axes': [0],
'name': string(var_r0)})
var_r0t = node.layer_name + '_r0t'
node.fluid_code.add_layer(
'transpose',
inputs=var_r0,
output=var_r0t,
param_attr={'perm': [1, 0],
'name': string(var_r0t)})
if val_b:
var_bi = node.layer_name + '_bi'
var_bh = node.layer_name + '_bh'
node.fluid_code.add_layer(
'split',
inputs=val_b,
output=var_bi + ',' + var_bh,
param_attr={
'axis': 1,
'split': [hidden_size * 3, hidden_size * 3],
'name': string(node.layer_name + '.b/split')
})
var_bi0 = node.layer_name + '_bi0'
node.fluid_code.add_layer(
'squeeze',
inputs=var_bi,
output=var_bi0,
param_attr={'axes': [0],
'name': string(var_bi0)})
node.fluid_code.add_layer(
'elmentwise_add',
inputs=[var_mm, var_bi0],
output=var_fc,
param_attr={
'axes': 1,
'name': string(node.layer_name + '.i/bias')
})
if val_xh:
var_xh0 = node.layer_name + '_xh0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_xh,
output=var_xh0,
param_attr={'axes': [1],
'name': string(var_xh0)})
var_y00 = node.layer_name + '_y00'
attr = {
'origin_mode': True,
'h_0': var_xh0 if val_xh else None,
'is_reverse': is_reverse,
'gate_activation': string(gate_activation),
'candidate_activation': string(candidate_activation),
'param_attr': string(var_r0t),
'bias_attr': string(var_bh) if val_b else False,
}
node.fluid_code.add_layer(
'dynamic_gru',
inputs=var_fc + ',' + str(hidden_size),
output=var_y00,
param_attr=attr)
num_opt = len(node.layer.output)
if num_opt > 0 and node.layer.output[0] != '':
node.fluid_code.add_layer(
'unsqueeze',
inputs=var_y00,
output=node.layer.output[0],
param_attr={
'axes': [1, 1],
'name': string(node.layer.output[0])
})
if num_opt > 1 and node.layer.output[1] != '':
node.fluid_code.add_layer(
'unsqueeze',
inputs=var_y00,
output=node.layer.output[1],
param_attr={
'axes': [1, 1],
'name': string(node.layer.output[1])
})
opset = 'OpSet' + str(run_op_set)
print(
'Now, onnx2paddle support convert onnx model opset_verison {},'
'opset_verison of your onnx model is {}, automatically treated as op_set: {}.'
.format(self.support_op_sets, decoder.op_set, run_op_set))
return eval(opset)(decoder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string
from functools import reduce
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
import logging as _logging
from collections import OrderedDict
import math
import os
import shutil
_logger = _logging.getLogger(__name__)
def _const_weight_or_none(node):
if 'Constant' in node.layer_type:
return node.value
if isinstance(node, ONNXGraphDataNode):
return node.weight
return None
def get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
pad0 = int(pad_size / 2)
pad1 = pad_size - pad0
return [pad0, pad1]
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
print("convert failed node:{}, op_type is {}".format(
node.layer_name[9:], node.layer_type))
raise
else:
#print("convert successfully node:{}, op_type is {}".format(
# node.layer_name[9:], node.layer_type))
return res
return run_mapping
class OpSet9():
elementwise_ops = {
'Add': 'elementwise_add',
'Div': 'elementwise_div',
'Sub': 'elementwise_sub',
'Mul': 'elementwise_mul',
'Pow': 'elementwise_pow',
}
default_op_mapping_field_values = OrderedDict()
default_op_mapping_field_values['FLUID_OP'] = ''
default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None
default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None
default_op_mapping_field_values['ATTR_MAPPING'] = dict()
default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'], dict(), dict(
min=(np.asarray(
[255, 255, 127, 255], dtype=np.uint8).view(np.float32)[0]),
max=(np.asarray(
[255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]),
)
],
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceSum': [
'reduce_sum', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMin': [
'reduce_min', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
#active function
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'),
dict(alpha=1.)
],
'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], dict(
alpha='slope', beta='offset'), dict(
slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']],
}
default_ioa_constraint = {
'Gather':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')],
}
def __init__(self, decoder):
super(OpSet9, self).__init__()
self.graph = decoder.graph
self.input_shapes = []
self.weights = dict()
self.omit_nodes = list()
self.used_custom_layers = dict()
@print_mapping_info
def directly_map(self, node, name='', *args, **kwargs):
inputs = node.layer.input
outputs = node.layer.output
op_type = node.layer_type
attrs = node.attr_map
info = self.default_op_mapping[op_type]
info.extend(
list(self.default_op_mapping_field_values.values())[len(info):])
(
fluid_op,
fluid_input_args,
fluid_output_args,
attr_mapping,
default_attrs,
input_perm,
output_perm,
fill_name_field, ) = info
if fluid_op in self.default_ioa_constraint:
for predicate, message in self.default_ioa_constraint[fluid_op]:
assert predicate(inputs, outputs, attrs), message
mapped_attrs = {
attr_mapping.get(key, key): value
for key, value in attrs.items()
}
if '' in mapped_attrs:
mapped_attrs.pop('')
if '_' in mapped_attrs:
mapped_attrs.pop('_')
fluid_attrs = default_attrs.copy()
fluid_attrs.update(mapped_attrs)
inputs = inputs if input_perm is None else list(
map(lambda i: inputs[i], input_perm))
val_inps = []
for idx, ipt in enumerate(inputs):
val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
val_outs = outputs if output_perm is None else list(
map(lambda i: outputs[i], output_perm))
attr = fluid_attrs
assert len(val_inps) == 1, 'directly_map error with multi inputs'
if fluid_op not in ['shape', 'erf']:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_inps[0], output=val_outs[0], param_attr=attr)
if fluid_op in ['shape']:
node.fluid_code.add_layer(
'cast',
inputs=val_outs[0],
output=val_outs[0],
param_attr={'dtype': string('int64')})
@print_mapping_info
def deal_custom_layer(self, node):
op = node.layer_type
custom_code, func = make_custom_layer(node)
child_func_code, child_func = make_custom_child_func(node)
params = get_params(node.layer, node.layer_type)
arg_names, kwargs = set_args(func, params)
kwargs['name'] = string(node.layer_name)
node.fluid_code.add_layer(
func.__code__.co_name,
inputs=node.inputs,
output=node,
param_attr=kwargs,
is_custom_layer=True)
if op not in self.used_custom_layers:
self.used_custom_layers[op] = custom_code
if op + '_child_func' not in self.used_custom_layers:
if child_func_code is not None:
self.used_custom_layers[op +
'_child_func'] = child_func_code
@print_mapping_info
def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) < len(val_y_shape):
val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0
if str_y_shape not in str_x_shape:
for dim in val_y_shape:
if dim == 1:
slice_idx += 1
else:
break
attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped'
attr_reshaped = {
'shape': val_y_reshaped,
'name': string(var_y_reshaped)
}
node.fluid_code.add_layer(
'reshape',
inputs=val_y,
output=var_y_reshaped,
param_attr=attr_reshaped)
inputs = {'x': val_x, 'y': var_y_reshaped}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
else:
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
@print_mapping_info
def place_holder(self, node):
self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = {
"dtype": string(node.dtype),
"shape": shape,
"name": string(node.layer_name),
"append_batch_size": 'False'
}
node.fluid_code.add_layer(
"data", inputs=None, output=node, param_attr=attr)
@print_mapping_info
def create_parameter(self, node, parameter=None):
if parameter is not None:
node = parameter
dtype = node.dtype
shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
shape = [1]
self.weights[node.layer_name] = node.weight
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
if dtype == 'bool':
attr['dtype'] = string('int64')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
node.fluid_code.add_layer(
"cast",
inputs=node,
output=node,
param_attr={'dtype': string('bool')})
elif dtype == 'uint8':
attr['dtype'] = string('float32')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
else:
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE
assert len(pads) & 1 == 0
symmetric = True
ndims = len(pads) // 2
for idx_dim in range(ndims):
if pads[idx_dim] != pads[ndims + idx_dim]:
symmetric = False
break
if symmetric:
return pads[:ndims], val_name
val_padded = self.Pad(node, op_independent=False)
return [0] * ndims, val_padded
def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
if node.layer_type == 'Resize':
val_scales = self.graph.get_input_node(node, idx=2, copy=True)
elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
attr = {'name': string(node.layer_name)}
mode = node.get_attr('mode', 'nearest')
fluid_op = 'resize_{}'.format(mode)
if 'linear' in mode:
print(
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
)
fluid_op = 'resize_bilinear'
node.fluid_code.add_layer(
fluid_op,
inputs={'input': val_x,
'scale': val_scales},
output=node,
param_attr=attr)
@print_mapping_info
def RoiAlign(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
pooled_height = node.get_attr('output_height')
pooled_width = node.get_attr('output_width')
spatial_scale = node.get_attr('spatial_scale')
sampling_ratio = node.get_attr('sampling_ratio')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio,
}
node.fluid_code.add_layer(
'roi_align',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
@print_mapping_info
def MaxRoiPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
spatial_scale = node.get_attr('spatial_scale')
pooled_height, pooled_width = node.get_attr('pooled_shape')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
}
node.fluid_code.add_layer(
'roi_pool',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
@print_mapping_info
def Pad(self, node, op_independent=True):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
pads = node.get_attr('pads')
mode = node.get_attr('mode', 'constant')
value = node.get_attr('value', 0.)
data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0]
assume_pad2d = False
attr = {}
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
if data_shape:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
if assume_pad2d:
fluid_op = 'pad2d'
attr['data_format'] = string('NCHW')
attr['mode'] = string(mode)
else:
attr = {'pad_value': value}
fluid_op = 'pad'
if len(pads) == 4:
paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
elif len(pads) == 8:
paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0:
fluid_op = 'pad2d'
paddings = paddings[4:]
attr['mode'] = string(mode)
attr['paddings'] = paddings
if op_independent:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
else:
attr['name'] = string(node.layer_name + '_paded')
node.fluid_code.add_layer(
fluid_op,
inputs=val_x,
output=node.layer_name + '_paded',
param_attr=attr)
return node.layer_name + '_paded'
@print_mapping_info
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)}
if len(val_x.out_shapes[0]) == 0:
if node.layer_name:
node.fluid_code.add_layer(
'reshape',
inputs=val_x,
output=node,
param_attr={'shape': [1]})
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Shrink(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
bias = node.get_attr('bias')
lambd = node.get_attr('lambd')
assert bias == 0.0, 'not support bias!=0'
attr = {'threshold': lambd, 'name': node.layer_name}
node.fluid_code.add_layer(
'hard_shrink', inputs=val_x, output=node, param_attr=attr)
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
'greater_than',
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
@print_mapping_info
def Constant(self, node):
val_output = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
dtype = np.dtype(value.dtype)
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
shape = node.get_attr('shape', None)
if shape is None:
shape = val_output.out_shapes[0]
if shape is None:
shape = list(value.shape)
_logger.warning('in (Constant -> %s): '
'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails',
val_output.layer_name, val_output.layer_name)
if len(value) == 1:
value = value.tolist()
shape = [1]
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
else:
if dtype.name == 'uint8':
dtype = 'int64'
value = np.reshape(value, shape)
self.weights[node.layer_name] = value
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
@print_mapping_info
def Resize(self, node):
self._interpolate(node)
@print_mapping_info
def Upsample(self, node):
self._interpolate(node)
@print_mapping_info
def InstanceNormalization(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True)
epsilon = node.get_attr('epsilon', 1e-5)
attr = {
'epsilon': epsilon,
'param_attr': string(val_scale.layer_name),
'bias_attr': string(val_b.layer_name)
}
node.fluid_code.add_layer(
"instance_norm", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer(
'elementwise_mul',
inputs=inputs,
output=node.layer_name,
param_attr=attr)
@print_mapping_info
def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0]
axis = node.get_attr('axis', 0)
#assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
'index': indices},
output=node,
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
node.fluid_code.add_layer(
'embedding',
inputs=indices,
output=node,
use_fluid=True,
param_attr={
'param_attr': string(val_x.layer_name),
'size': val_x.out_shapes[0]
})
else:
from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
'reshape',
inputs=indices,
output=indices_reshape,
param_attr={'shape': [reshape_shape, ]})
perm = list(range(len(val_x.out_shapes[0])))
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices_reshape},
output=node,
param_attr=None)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
elif axis > 0 and len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
'reshape',
inputs=indices,
output=indices_reshape,
param_attr={'shape': [reshape_shape, ]})
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
'index': indices_reshape},
output=node,
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
@print_mapping_info
def Range(self, node):
val_start = self.graph.get_input_node(node, idx=0, copy=True)
val_limit = self.graph.get_input_node(node, idx=1, copy=True)
val_delta = self.graph.get_input_node(node, idx=2, copy=True)
dtype = val_start.dtype
inputs = {'start': val_start, 'end': val_limit, 'step': val_delta}
node.fluid_code.add_layer(
'range',
inputs=inputs,
output=node,
param_attr={'dtype': string(dtype)})
@print_mapping_info
def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None
attr = {}
if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True)
if len(node.inputs) > 3:
axes = self.graph.get_input_node(node, idx=3, copy=True)
axes = _const_weight_or_none(axes)
if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps)
if steps is not None:
assert steps == 1, "Only support convert op:Slice, which attribute:steps == 1"
attr = {
"axes": axes,
"starts": starts.layer_name,
"ends": ends.layer_name
}
starts_value = _const_weight_or_none(starts)
ends_value = _const_weight_or_none(ends)
if starts_value is not None and ends_value is not None:
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
ends_value = ends_value.copy()
for idx in range(len(ends_value)):
if ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1
attr = {
"axes": axes,
"starts": starts_value,
"ends": ends_value
}
else:
if starts.dtype != 'int32':
node.fluid_code.add_layer(
'cast',
inputs=starts,
output=starts,
param_attr={'dtype': string('int32')})
if ends.dtype != 'int32':
node.fluid_code.add_layer(
'cast',
inputs=ends,
output=ends,
param_attr={'dtype': string('int32')})
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
axes = node.get_attr('axes')
for idx in range(len(ends)):
if ends[idx] > 2**31 - 1:
ends[idx] = 2**31 - 1
attr = {"axes": axes, "starts": starts, "ends": ends}
node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def ConstantOfShape(self, node):
val_shape = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
dtype = value.dtype
value = value.tolist()
assert len(value) == 1, ('given value not Scalar, shape of value > 1, '
'this is not supported')
if len(value) == 1:
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {
'shape': val_shape.layer_name,
'dtype': string(dtype),
'value': value
}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
@print_mapping_info
def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'split'
split = node.get_attr('split')
axis = node.get_attr('axis', 0)
attr = {
'num_or_sections': split,
'dim': axis,
'name': string(node.layer_name)
}
node.fluid_code.add_layer(
'split', inputs=val_x, output=val_y, param_attr=attr)
@print_mapping_info
def Reshape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
attr = {}
shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None:
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
node.fluid_code.add_layer(
'reshape',
inputs=val_shape_cast,
output=val_shape_cast,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': val_shape_cast},
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer(
'reshape',
inputs=val_shape,
output=val_shape,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': val_shape},
output=node,
param_attr=attr)
@print_mapping_info
def Cast(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
val_output = self.graph.get_node(node.layer.output[0], copy=True)
dtype = node.get_attr('to')
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)}
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
@print_mapping_info
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr("kernel_shape")
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0))
pads = node.get_attr('pads', [0] * (poolnd * 2))
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0]
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"pool_size": kernel_shape,
"pool_type": string('avg'),
"pool_stride": strides,
"pool_padding": paddings,
"ceil_mode": ceil_mode,
"exclusive": 'True',
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Concat(self, node):
inputs = []
for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str):
inputs.append(ipt)
else:
inputs.append(ipt.layer_name)
axis = node.get_attr('axis')
attr = {'axis': axis}
node.fluid_code.add_layer(
'concat', inputs=inputs, output=node, param_attr=attr)
@print_mapping_info
def Flatten(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis', 1)
attr = {"axis": str(axis), "name": string(node.layer_name)}
node.fluid_code.add_layer(
'flatten', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Gemm(self, node):
val_a = self.graph.get_input_node(node, idx=0, copy=True)
val_b = self.graph.get_input_node(node, idx=1, copy=True)
val_c = self.graph.get_input_node(node, idx=2, copy=True)
alpha = node.get_attr('alpha', 1.) # optional
beta = node.get_attr('beta', 1.) # optional
trans_a = bool(node.get_attr('transA', 0)) # optional
trans_b = bool(node.get_attr('transB', 0)) # optional
val_mm = node.layer_name + '_mm'
matmul_inputs = {"x": val_a, "y": val_b}
attr_matmul = {
"transpose_x": trans_a,
"transpose_y": trans_b,
"alpha": alpha,
"name": string(val_mm)
}
node.fluid_code.add_layer(
'matmul',
inputs=matmul_inputs,
output=val_mm,
param_attr=attr_matmul)
if beta != 0:
if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
inputs=add_inputs,
output=node,
param_attr=attr)
else:
var_beta = node.layer_name + '_beta'
matmul_beta_inputs = {"x": val_c, "y": var_beta}
node.fluid_code.add_layer(
"Constant",
inputs=matmul_beta_inputs,
output=var_beta,
param_attr={'value': beta})
add_inputs = {"x": val_mm, "y": var_beta}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
inputs=add_inputs,
output=node,
param_attr=attr)
@print_mapping_info
def Sum(self, node):
val_inps = node.layer.input
inputs = {
"x": self.graph.get_input_node(
node, idx=0, copy=True),
"y": self.graph.get_input_node(
node, idx=1, copy=True),
}
node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)
for idx, ipt in enumerate(val_inps[2:]):
y = self.graph.get_input_node(node, idx=idx, copy=True)
inputs = {
"x": node.layer_name,
"y": y,
}
node.fluid_code.add_layer(
"elementwise_add", inputs=inputs, output=node)
@print_mapping_info
def MatMul(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {"x": val_x, "y": val_y}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=attr)
@print_mapping_info
def BatchNormalization(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True)
val_mean = self.graph.get_input_node(node, idx=3, copy=True)
val_var = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(val_scale.layer_name)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_mean.layer_name)
self.omit_nodes.append(val_var.layer_name)
momentum = node.get_attr('momentum', .9)
epsilon = node.get_attr('epsilon', 1e-5)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
attr = {
"momentum": momentum,
"epsilon": epsilon,
"data_layout": string('NCHW'),
"is_test": True,
"param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name),
"moving_mean_name": string(val_mean.layer_name),
"moving_variance_name": string(val_var.layer_name),
"use_global_stats": spatial,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
"batch_norm", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Transpose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
perm = node.get_attr('perm')
attr = {'perm': perm, "name": string(node.layer_name)}
node.fluid_code.add_layer(
"transpose", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Relu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"relu", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def PRelu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True)
mode = 'channel'
shape_slope = val_slope.out_shapes[0]
if len(shape_slope) == 1:
mode = 'all'
elif len(shape_slope) > 2:
mode = 'element'
attr = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode)
}
node.fluid_code.add_layer(
"prelu", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Squeeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, "name": string(node.layer_name)}
if len(val_x.out_shapes[0]) == 1:
node.fluid_code.add_layer(
"cast",
inputs=val_x,
output=node,
param_attr={'dtype': string(val_x.dtype)})
else:
node.fluid_code.add_layer(
"squeeze", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Equal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
"equal",
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
@print_mapping_info
def Where(self, node):
condition = self.graph.get_input_node(node, idx=0, copy=True)
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.layer_name + '_not'
node.fluid_code.add_layer(
"logical_not",
inputs=condition,
output=not_condition,
param_attr=None)
cast_not_condition = not_condition + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=not_condition,
output=cast_not_condition,
param_attr={'dtype': string(val_x.dtype)})
cast_condition = condition.layer_name + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=condition,
output=cast_condition,
param_attr={'dtype': string(val_x.dtype)})
mul_val_x = val_x.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_x,
'y': cast_condition},
output=mul_val_x,
param_attr=None)
mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_y,
'y': cast_not_condition},
output=mul_val_y,
param_attr=None)
node.fluid_code.add_layer(
"elementwise_add",
inputs={'x': mul_val_x,
'y': mul_val_y},
output=node,
param_attr=None)
@print_mapping_info
def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x_dim = len(val_x.out_shapes[0])
print(val_x.layer_name, val_x.out_shapes[0])
if val_x_dim == 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
node.fluid_code.add_layer(
"transpose",
inputs=val_x,
output=node,
param_attr={'perm': [1, 0]})
if val_x_dim > 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
node.fluid_code.add_layer(
"split",
inputs=val_x,
output=val_x,
param_attr={'num_or_sections': 1,
'dim': val_x_dim})
node.fluid_code.add_layer("concat", inputs=val_x, output=node)
@print_mapping_info
def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer("assign", inputs=val_x, output=node)
@print_mapping_info
def Tile(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
repeats = _const_weight_or_none(val_repeats)
if repeats is None:
repeats = val_repeats.layer_name
elif isinstance(repeats, int):
repeats = [repeats]
attr = {
'expand_times': repeats,
"name": string(node.layer_name),
}
node.fluid_code.add_layer(
"expand", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def MaxPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
assert node.get_attr(
"dilations") is None, 'only dilations = 0 is supported' # optional
kernel_shape = node.get_attr("kernel_shape")
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional
pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0]
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"pool_size": kernel_shape,
"pool_type": string("max"),
"pool_stride": strides,
"pool_padding": paddings,
"ceil_mode": ceil_mode,
"name": string(node.layer_name),
"exclusive": False
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def _global_pool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'pool2d'
pool_type = None
if node.layer.op_type == 'GlobalMaxPool':
pool_type = 'max'
elif node.layer.op_type == 'GlobalAveragePool':
pool_type = 'avg'
attr = {
"pool_type": string(pool_type),
"global_pooling": True,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def GlobalMaxPool(self, node):
self._global_pool(node)
@print_mapping_info
def GlobalAveragePool(self, node):
self._global_pool(node)
@print_mapping_info
def Conv(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
self.omit_nodes.append(val_w.layer_name)
has_bias = len(node.layer.input) == 3
if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr('kernel_shape')
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
num_out_channels = val_w.out_shapes[0][0] # OI...
fluid_op = 'conv{}d'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd) # optional
dilations = node.get_attr('dilations', [1] * convnd) # optional
pads = node.get_attr('pads', [0] * (convnd * 2)) # optional
input_shape = val_x.out_shapes[0]
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_h = get_same_padding(input_shape[2], kernel_shape[0],
strides[0])
pad_w = get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
attr = {
"num_filters": num_out_channels,
"filter_size": kernel_shape,
"stride": strides,
"padding": paddings,
"dilation": dilations,
"groups": num_groups,
'param_attr': string(val_w.layer_name),
"name": string(node.layer_name)
}
if has_bias:
attr["bias_attr"] = string(val_b.layer_name)
else:
attr["bias_attr"] = False
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def ConvTranspose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_b = None
if len(node.layer.input) > 2:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_w.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
out_padding = node.get_attr('output_padding', [0, 0])
kernel_shape = node.get_attr('kernel_shape')
assert kernel_shape, 'kernel_shape not inferred'
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_out_channels = val_w.out_shapes[0][1]
fluid_op = 'conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd)
dilations = node.get_attr('dilations', [1] * convnd)
output_size = node.get_attr('output_shape', [])
pads = node.get_attr('pads', [0] * (convnd * 2))
paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)
output_size = [0, 0]
output_size[0] = (val_x.out_shapes[0][2] - 1
) * strides[0] - 2 * paddings[0] + dilations[0] * (
kernel_shape[0] - 1) + 1 + out_padding[0]
output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1]
attr = {
'num_filters': num_out_channels,
'output_size': output_size or None,
'filter_size': kernel_shape,
'padding': paddings,
'stride': strides,
'dilation': dilations,
'groups': num_groups,
'param_attr': string(val_w.layer_name),
'bias_attr': None if val_b is None else string(val_b.layer_name),
'name': string(node.layer_name),
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def GRU(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_r = self.graph.get_input_node(node, idx=2, copy=True)
val_b = None
val_len = None
val_xh = None
miss_arg_num = 0
num_ipt = len(node.layer.input)
if num_ipt > 3 and node.layer.input[3] != '':
val_b = self.graph.get_input_node(node, idx=3, copy=True)
else:
miss_arg_num += 1
if num_ipt > 4 and node.layer.input[4] != '':
val_len = self.graph.get_input_node(
node, idx=4 - miss_arg_num, copy=True)
else:
miss_arg_num += 1
if num_ipt > 5 and node.layer.input[5] != '':
val_xh = self.graph.get_input_node(
node, idx=5 - miss_arg_num, copy=True)
x_shape = val_x.out_shapes[0]
assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert node.get_attr('clip', None) is None, 'clipping not supported'
hidden_size = node.get_attr('hidden_size', None)
if hidden_size is None:
r_shape = val_r.out_shapes[0]
if r_shape:
hidden_size = r_shape[-1]
if hidden_size is None:
w_shape = var_w.out_shapes[0]
if w_shape:
hidden_size = w_shape[-2] // 3
if hidden_size is None and val_b:
b_shape = val_b.out_shapes[0]
if b_shape:
hidden_size = b_shape[-1] // 6
if hidden_size is None and val_xh:
xh_shape = val_xh.out_shapes[0]
if xh_shape:
hidden_size = xh_shape[-1]
direction = node.get_attr('direction', 'forward')
assert direction != 'bidirectional', 'direction = bidirectional not supported'
activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
assert len(activations) == 2, 'bidirectional operation not supported'
assert node.get_attr('linear_before_reset',
0) == 0, 'only linear_before_reset = 0 supported'
activations = [s.lower() for s in activations]
gate_activation, candidate_activation = activations
is_reverse = direction == 'reverse'
var_x0 = node.layer_name + '_x0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_x,
output=var_x0,
param_attr={'axes': [1],
'name': string(var_x0)})
var_w0 = node.layer_name + '_w0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_w,
output=var_w0,
param_attr={'axes': [0],
'name': string(var_w0)})
var_fc = node.layer_name + '_fc'
var_mm = (node.layer_name + '_mm') if val_b else var_fc
node.fluid_code.add_layer(
'matmul',
inputs={'x': var_x0,
'y': var_w0},
output=var_mm,
param_attr={
'transpose_x': 0,
'transpose_y': 1,
'name': string(var_mm)
})
var_r0 = node.layer_name + '_r0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_r,
output=var_r0,
param_attr={'axes': [0],
'name': string(var_r0)})
var_r0t = node.layer_name + '_r0t'
node.fluid_code.add_layer(
'transpose',
inputs=var_r0,
output=var_r0t,
param_attr={'perm': [1, 0],
'name': string(var_r0t)})
if val_b:
var_bi = node.layer_name + '_bi'
var_bh = node.layer_name + '_bh'
node.fluid_code.add_layer(
'split',
inputs=val_b,
output=var_bi + ',' + var_bh,
param_attr={
'axis': 1,
'split': [hidden_size * 3, hidden_size * 3],
'name': string(node.layer_name + '.b/split')
})
var_bi0 = node.layer_name + '_bi0'
node.fluid_code.add_layer(
'squeeze',
inputs=var_bi,
output=var_bi0,
param_attr={'axes': [0],
'name': string(var_bi0)})
node.fluid_code.add_layer(
'elmentwise_add',
inputs=[var_mm, var_bi0],
output=var_fc,
param_attr={
'axes': 1,
'name': string(node.layer_name + '.i/bias')
})
if val_xh:
var_xh0 = node.layer_name + '_xh0'
node.fluid_code.add_layer(
'squeeze',
inputs=val_xh,
output=var_xh0,
param_attr={'axes': [1],
'name': string(var_xh0)})
var_y00 = node.layer_name + '_y00'
attr = {
'origin_mode': True,
'h_0': var_xh0 if val_xh else None,
'is_reverse': is_reverse,
'gate_activation': string(gate_activation),
'candidate_activation': string(candidate_activation),
'param_attr': string(var_r0t),
'bias_attr': string(var_bh) if val_b else False,
}
node.fluid_code.add_layer(
'dynamic_gru',
inputs=var_fc + ',' + str(hidden_size),
output=var_y00,
param_attr=attr)
num_opt = len(node.layer.output)
if num_opt > 0 and node.layer.output[0] != '':
node.fluid_code.add_layer(
'unsqueeze',
inputs=var_y00,
output=node.layer.output[0],
param_attr={
'axes': [1, 1],
'name': string(node.layer.output[0])
})
if num_opt > 1 and node.layer.output[1] != '':
node.fluid_code.add_layer(
'unsqueeze',
inputs=var_y00,
output=node.layer.output[1],
param_attr={
'axes': [1, 1],
'name': string(node.layer.output[1])
})
......@@ -2,6 +2,8 @@ import onnx
import numpy as np
from onnx import onnx_pb, helper
MAX_FLOAT = np.asarray([255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]
def get_old_name(arg, name_prefix=''):
prefix_index = arg.find(name_prefix)
......@@ -747,36 +749,53 @@ def yolo_box(op, block):
outputs_pred_box_x2_clip = [model_name + "@pred_box_x2_clip"]
outputs_pred_box_y2_clip = [model_name + "@pred_box_y2_clip"]
min_const_name = model_name + "@pred_box_min_const"
max_const_name = model_name + "@pred_box_max_const"
min_const = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[min_const_name],
value=onnx.helper.make_tensor(
name=min_const_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[0.0]))
node_list.append(min_const)
max_const = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=[max_const_name],
value=onnx.helper.make_tensor(
name=max_const_name,
data_type=onnx.TensorProto.FLOAT,
dims=(),
vals=[MAX_FLOAT]))
node_list.append(max_const)
node_pred_box_x1_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_x1_decode,
outputs=outputs_pred_box_x1_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_x1_decode + [min_const_name, max_const_name],
outputs=outputs_pred_box_x1_clip)
node_list.append(node_pred_box_x1_clip)
node_pred_box_y1_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_y1_decode,
outputs=outputs_pred_box_y1_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_y1_decode + [min_const_name, max_const_name],
outputs=outputs_pred_box_y1_clip)
node_list.append(node_pred_box_y1_clip)
node_pred_box_x2_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_x2_sub_w,
outputs=outputs_pred_box_x2_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_x2_sub_w + [min_const_name, max_const_name],
outputs=outputs_pred_box_x2_clip)
node_list.append(node_pred_box_x2_clip)
node_pred_box_y2_clip = onnx.helper.make_node(
'Clip',
inputs=outputs_pred_box_y2_sub_h,
outputs=outputs_pred_box_y2_clip,
min=0.0,
max=float(np.inf))
inputs=outputs_pred_box_y2_sub_h + [min_const_name, max_const_name],
outputs=outputs_pred_box_y2_clip)
node_list.append(node_pred_box_y2_clip)
outputs_pred_box_x2_res = [model_name + "@box_x2_res"]
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# TODO useless node remove
from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
class ONNXOptimizer(object):
......
......@@ -48,8 +48,8 @@
## ONNX
**注:** 部分模型来源于PyTorch,PyTorch的转换可参考[pytorch_to_onnx.md](pytorch_to_onnx.md)
| 模型 | 来源 | operator version|
|-------|--------|---------|
| 模型 | 来源 | operator version|备注|
|-------|--------|---------|---------|
| ResNet18 | [torchvison.model.resnet18](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
| ResNet34 | [torchvison.model.resnet34](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
| ResNet50 | [torchvison.model.resnet50](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) |9|
......@@ -65,4 +65,6 @@
| mNASNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| EfficientNet | [pytorch(personal practice)](https://github.com/rwightman/gen-efficientnet-pytorch) |9|
| SqueezeNet | [onnx official](https://s3.amazonaws.com/download.onnx/models/opset_9/squeezenet.tar.gz) |9|
|Ultra-Light-Fast-Generic-Face-Detector-1MB| [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx)| |
|Ultra-Light-Fast-Generic-Face-Detector-1MB| [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx)|9 |
|BERT| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|转换时需指定input shape,见[文档Q3](FAQ.md)|
|GPT2| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|转换时需指定input shape,见[文档Q3](FAQ.md)|
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册