提交 c676f021 编写于 作者: C Channingss

merge elementwise op convert, supported dynamic scale for resize op

上级 84e9e190
...@@ -7,6 +7,7 @@ function abort(){ ...@@ -7,6 +7,7 @@ function abort(){
trap 'abort' 0 trap 'abort' 0
set -e set -e
TRAVIS_BUILD_DIR=${PWD}
cd $TRAVIS_BUILD_DIR cd $TRAVIS_BUILD_DIR
export PATH=/usr/bin:$PATH export PATH=/usr/bin:$PATH
pre-commit install pre-commit install
......
...@@ -57,7 +57,8 @@ class ONNXOpMapper(OpMapper): ...@@ -57,7 +57,8 @@ class ONNXOpMapper(OpMapper):
'Div': 'elementwise_div', 'Div': 'elementwise_div',
'Sub': 'elementwise_sub', 'Sub': 'elementwise_sub',
'Mul': 'elementwise_mul', 'Mul': 'elementwise_mul',
'Pow': 'elementwise_pow',} 'Pow': 'elementwise_pow',
}
def __init__(self, decoder, save_dir): def __init__(self, decoder, save_dir):
super(ONNXOpMapper, self).__init__() super(ONNXOpMapper, self).__init__()
...@@ -160,8 +161,8 @@ class ONNXOpMapper(OpMapper): ...@@ -160,8 +161,8 @@ class ONNXOpMapper(OpMapper):
for opt in layer.output: for opt in layer.output:
if opt in value_infos: if opt in value_infos:
value_info = value_infos[opt] value_info = value_infos[opt]
if len(value_info['shape'] if len(value_info['shape']) == 0 or value_info[
) == 0 or value_info['dtype'] is None or 0 in value_info['shape']: 'dtype'] is None or 0 in value_info['shape']:
if self.is_inference == False: if self.is_inference == False:
self.get_results_of_inference( self.get_results_of_inference(
onnx_model, value_infos, onnx_model, value_infos,
...@@ -258,13 +259,14 @@ class ONNXOpMapper(OpMapper): ...@@ -258,13 +259,14 @@ class ONNXOpMapper(OpMapper):
if child_func_code is not None: if child_func_code is not None:
self.used_custom_layers[op + self.used_custom_layers[op +
'_child_func'] = child_func_code '_child_func'] = child_func_code
def elementwise_map(self, node): def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type] op_type = self.elementwise_ops[node.layer_type]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_x.out_shapes[0])<len(val_y.out_shapes[0]): if len(val_x.out_shapes[0]) < len(val_y.out_shapes[0]):
val_x, val_y = val_y, val_x val_x, val_y = val_y, val_x
val_y_shape = val_y.out_shapes[0] val_y_shape = val_y.out_shapes[0]
...@@ -276,7 +278,6 @@ class ONNXOpMapper(OpMapper): ...@@ -276,7 +278,6 @@ class ONNXOpMapper(OpMapper):
slice_idx += 1 slice_idx += 1
else: else:
break break
attr = {"name": string(node.layer_name)} attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0: if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:] val_y_reshaped = val_y_shape[slice_idx:]
...@@ -380,7 +381,9 @@ class ONNXOpMapper(OpMapper): ...@@ -380,7 +381,9 @@ class ONNXOpMapper(OpMapper):
fluid_op = 'resize_{}'.format(mode) fluid_op = 'resize_{}'.format(mode)
if 'linear' in mode: if 'linear' in mode:
print('Warnning: paddle not support resize wiht mode: linear, we use bilinear replace linear') print(
'Warnning: paddle not support resize wiht mode: linear, we use bilinear replace linear'
)
fluid_op = 'resize_bilinear' fluid_op = 'resize_bilinear'
if isinstance(val_scales, ONNXGraphNode): if isinstance(val_scales, ONNXGraphNode):
...@@ -447,7 +450,7 @@ class ONNXOpMapper(OpMapper): ...@@ -447,7 +450,7 @@ class ONNXOpMapper(OpMapper):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if len(val_x.out_shapes[0])==0: if len(val_x.out_shapes[0]) == 0:
node.fluid_code.add_layer('assign', node.fluid_code.add_layer('assign',
inputs=val_x, inputs=val_x,
output=node, output=node,
...@@ -459,9 +462,6 @@ class ONNXOpMapper(OpMapper): ...@@ -459,9 +462,6 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Shrink(self, node): def Shrink(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
bias = node.get_attr('bias') bias = node.get_attr('bias')
...@@ -845,7 +845,6 @@ class ONNXOpMapper(OpMapper): ...@@ -845,7 +845,6 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Sum(self, node): def Sum(self, node):
val_inps = node.layer.input val_inps = node.layer.input
inputs = { inputs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册