未验证 提交 2bc0a830 编写于 作者: Y yeliang2258 提交者: GitHub

Enhanced conv and pad functions in onnx opset9 (#707)

* Enhanced conv and pad functions in onnx opset9

* update code style

* code style

* fix style
上级 96ad07bb
...@@ -60,7 +60,7 @@ def _rename_or_remove_weight(weights, ...@@ -60,7 +60,7 @@ def _rename_or_remove_weight(weights,
None None
''' '''
if origin_name not in weights: if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights)) raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove: if is_remove:
# remove weight # remove weight
data = weights.pop(origin_name) data = weights.pop(origin_name)
...@@ -585,7 +585,21 @@ class OpSet9(): ...@@ -585,7 +585,21 @@ class OpSet9():
layer_attrs["pad"] = paddings layer_attrs["pad"] = paddings
paddle_op = "custom_layer:PadAllDim4WithOneInput" paddle_op = "custom_layer:PadAllDim4WithOneInput"
else: else:
raise Exception("The padding value {} is wrong!".format(pads)) pad_data = node.get_attr('pads')
pad_data1 = pad_data[0::2]
pad_data_all = []
for i in range(len(pad_data1)):
pad_data_all.append(pad_data[i])
pad_data_all.append(pad_data[len(pad_data1) + i])
layer_attrs["pad"] = pad_data_all
self.paddle_graph.add_layer(
'paddle.nn.functional.pad',
inputs={'x': val_x.name},
outputs=layer_outputs[1:],
**layer_attrs)
return
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
paddle_op, paddle_op,
inputs={'x': val_x.name}, inputs={'x': val_x.name},
...@@ -1982,11 +1996,17 @@ class OpSet9(): ...@@ -1982,11 +1996,17 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def Conv(self, node): def Conv(self, node):
op_name = name_generator("conv", self.nn_name2id)
output_name = node.name output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
if val_w.name in self.weights.keys():
op_name = name_generator("conv", self.nn_name2id)
else:
op_name = output_name
layer_outputs = [op_name, output_name]
has_bias = len(node.layer.input) == 3 has_bias = len(node.layer.input) == 3
if has_bias: if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -2015,6 +2035,25 @@ class OpSet9(): ...@@ -2015,6 +2035,25 @@ class OpSet9():
paddings = pad_h + pad_w paddings = pad_h + pad_w
layer_inputs = {'x': val_x if isinstance(val_x, str) else val_x.name} layer_inputs = {'x': val_x if isinstance(val_x, str) else val_x.name}
if val_w.name not in self.weights.keys():
layer_attrs = {
"stride": strides,
"padding": paddings,
"dilation": dilations,
"groups": num_groups,
}
layer_inputs['weight'] = val_w.name
if has_bias:
layer_inputs['bias'] = val_b.name
paddle_op = 'paddle.nn.functional.conv{}d'.format(convnd)
self.paddle_graph.add_layer(
paddle_op,
inputs=layer_inputs,
outputs=[node.name],
**layer_attrs)
return
layer_attrs = { layer_attrs = {
"in_channels": num_in_channels * num_groups, "in_channels": num_in_channels * num_groups,
"out_channels": num_out_channels, "out_channels": num_out_channels,
...@@ -2055,11 +2094,17 @@ class OpSet9(): ...@@ -2055,11 +2094,17 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def ConvTranspose(self, node): def ConvTranspose(self, node):
op_name = name_generator("conv_trans", self.nn_name2id)
output_name = node.name output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
if val_w.name in self.weights.keys():
op_name = name_generator("conv_trans", self.nn_name2id)
else:
op_name = output_name
layer_outputs = [op_name, output_name]
val_b = None val_b = None
if len(node.layer.input) > 2: if len(node.layer.input) > 2:
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -2092,6 +2137,27 @@ class OpSet9(): ...@@ -2092,6 +2137,27 @@ class OpSet9():
# Conv2DTranspose缺少output_size,只能在forward里头传进output_size # Conv2DTranspose缺少output_size,只能在forward里头传进output_size
inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name} inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name}
if val_w.name not in self.weights.keys():
layer_attrs = {
"stride": strides,
"dilation": dilations,
"padding": paddings,
"groups": num_groups,
"output_padding": out_padding
}
paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
inputs_dict['weight'] = val_w.name
if len(node.layer.input) > 2:
inputs_dict['bias'] = val_b.name
self.paddle_graph.add_layer(
paddle_op,
inputs=inputs_dict,
outputs=[node.name],
**layer_attrs)
return
layer_attrs = { layer_attrs = {
"in_channels": num_in_channels, "in_channels": num_in_channels,
"out_channels": num_out_channels * num_groups, "out_channels": num_out_channels * num_groups,
......
文件模式从 100644 更改为 100755
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册