未验证 提交 f05a1fe9 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #19 from SunAhong1993/develop

add
...@@ -205,7 +205,7 @@ class ONNXGraph(Graph): ...@@ -205,7 +205,7 @@ class ONNXGraph(Graph):
shape = raw_input( shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: " "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
) )
except: except NameError:
shape = input( shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: " "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
) )
...@@ -302,7 +302,18 @@ class ONNXGraph(Graph): ...@@ -302,7 +302,18 @@ class ONNXGraph(Graph):
if opt == in_node: if opt == in_node:
self.connect(nd.name, layer_name) self.connect(nd.name, layer_name)
flag = 1 flag = 1
node.which_child[nd.name] = idx if nd.name in node.which_child:
for n_i, n_ipt in enumerate(node.inputs):
if first_i == n_i:
continue
if n_ipt == nd.name:
new_nd_name = "{}/{}".format(nd.name, n_i)
if new_nd_name not in node.which_child:
node.which_child[new_nd_name] = idx
break
else:
first_i = node.inputs.index(nd.name)
node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0 self.node_map[nd.name].index = 0
break break
if flag == 1: if flag == 1:
...@@ -318,11 +329,15 @@ class ONNXGraph(Graph): ...@@ -318,11 +329,15 @@ class ONNXGraph(Graph):
if len(node.which_child) == 0: if len(node.which_child) == 0:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy) ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
return ipt_node return ipt_node
else: else:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy) ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
if ipt_node.layer_name in node.which_child: new_ipt_name = "{}/{}".format(ipt_node.layer_name, idx)
ipt_node.index = node.which_child[ipt_node.layer_name] if new_ipt_name in node.which_child:
ipt_node.index = node.which_child[new_ipt_name]
else:
if ipt_node.layer_name in node.which_child:
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node return ipt_node
......
...@@ -250,15 +250,22 @@ class OpSet9(): ...@@ -250,15 +250,22 @@ class OpSet9():
def _interpolate(self, node): def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'x': val_x.name} inputs = {'x': val_x.name}
attrs = dict()
if node.layer_type == 'Resize': if node.layer_type == 'Resize':
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True) val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
...@@ -281,7 +288,7 @@ class OpSet9(): ...@@ -281,7 +288,7 @@ class OpSet9():
ipt = inputs.pop("x") ipt = inputs.pop("x")
inputs["input"] = ipt inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False} attrs.update({"align_corners": False})
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest", kernel="fluid.layers.resize_nearest",
inputs=inputs, inputs=inputs,
...@@ -290,12 +297,12 @@ class OpSet9(): ...@@ -290,12 +297,12 @@ class OpSet9():
return return
elif node.layer_type == 'Upsample': elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales inputs['scale_factor'] = val_scales
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False, attrs.update({"align_corners": False,
"mode": string(mode), "mode": string(mode),
"align_mode": 1} "align_mode": 1})
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate", kernel="paddle.nn.functional.interpolate",
inputs=inputs, inputs=inputs,
...@@ -926,16 +933,17 @@ class OpSet9(): ...@@ -926,16 +933,17 @@ class OpSet9():
'max': max_value, 'max': max_value,
'min': min_value, 'min': min_value,
} }
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.clip', 'paddle.clip',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
else: else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True) min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True) max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt) min_value = _const_weight_or_none(min_ipt)
max_value = _const_weight_or_none(max_ipt)
if max_value.shape == (1, ): if max_value.shape == (1, ):
max_value = max_value[0] max_value = max_value[0]
if min_value.shape == (1, ): if min_value.shape == (1, ):
...@@ -1637,3 +1645,16 @@ class OpSet9(): ...@@ -1637,3 +1645,16 @@ class OpSet9():
inputs=inputs_dict, inputs=inputs_dict,
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
...@@ -240,15 +240,22 @@ class OpSet9(): ...@@ -240,15 +240,22 @@ class OpSet9():
def _interpolate(self, node): def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'x': val_x.name} inputs = {'x': val_x.name}
attrs = dict()
if node.layer_type == 'Resize': if node.layer_type == 'Resize':
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True) val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
...@@ -271,7 +278,7 @@ class OpSet9(): ...@@ -271,7 +278,7 @@ class OpSet9():
ipt = inputs.pop("x") ipt = inputs.pop("x")
inputs["input"] = ipt inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False} attrs.update({"align_corners": False})
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest", kernel="fluid.layers.resize_nearest",
inputs=inputs, inputs=inputs,
...@@ -283,9 +290,9 @@ class OpSet9(): ...@@ -283,9 +290,9 @@ class OpSet9():
inputs['scale'] = val_scales inputs['scale'] = val_scales
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False, attrs.update({"align_corners": False,
"mode": string(mode), "mode": string(mode),
"align_mode": 1} "align_mode": 1})
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate", kernel="paddle.nn.functional.interpolate",
inputs=inputs, inputs=inputs,
...@@ -917,10 +924,10 @@ class OpSet9(): ...@@ -917,10 +924,10 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
else: else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True) min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True) max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt) min_value = _const_weight_or_none(min_ipt)
max_value = _const_weight_or_none(max_ipt)
if max_value.shape == (1, ): if max_value.shape == (1, ):
max_value = max_value[0] max_value = max_value[0]
if min_value.shape == (1, ): if min_value.shape == (1, ):
...@@ -1576,4 +1583,17 @@ class OpSet9(): ...@@ -1576,4 +1583,17 @@ class OpSet9():
kernel=paddle_op, kernel=paddle_op,
inputs=layer_inputs, inputs=layer_inputs,
outputs=[node.name], outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs) **layer_attrs)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册