提交 4118254f 编写于 作者: C channingss

fix bug of convtranspose

上级 315a3057
...@@ -1231,10 +1231,11 @@ class ONNXOpMapper(OpMapper): ...@@ -1231,10 +1231,11 @@ class ONNXOpMapper(OpMapper):
def ConvTranspose(self, node): def ConvTranspose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = None
if len(node.layer.input)>2:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_w.layer_name) self.omit_nodes.append(val_w.layer_name)
self.omit_nodes.append(val_b.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True) val_y = self.graph.get_node(node.layer.output[0], copy=True)
...@@ -1272,7 +1273,7 @@ class ONNXOpMapper(OpMapper): ...@@ -1272,7 +1273,7 @@ class ONNXOpMapper(OpMapper):
'dilation': dilations, 'dilation': dilations,
'groups': num_groups, 'groups': num_groups,
'param_attr': string(val_w.layer_name), 'param_attr': string(val_w.layer_name),
'bias_attr': string(val_b.layer_name), 'bias_attr': None if val_b is None else string(val_b.layer_name),
'name': string(node.layer_name), 'name': string(node.layer_name),
} }
node.fluid_code.add_layer(fluid_op, node.fluid_code.add_layer(fluid_op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册