提交 ff5cdb1a 编写于 作者: S SunAhong1993

fix

上级 262229f5
...@@ -1588,9 +1588,6 @@ class OpSet9(): ...@@ -1588,9 +1588,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def ConvTranspose(self, node): def ConvTranspose(self, node):
op_name = name_generator("conv", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_b = None val_b = None
...@@ -1604,7 +1601,7 @@ class OpSet9(): ...@@ -1604,7 +1601,7 @@ class OpSet9():
assert 2 <= convnd <= 3, 'only Conv2DTranspose and Conv3DTranspose supported' assert 2 <= convnd <= 3, 'only Conv2DTranspose and Conv3DTranspose supported'
num_in_channels = val_w.out_shapes[0][0] num_in_channels = val_w.out_shapes[0][0]
num_out_channels = val_w.out_shapes[0][1] num_out_channels = val_w.out_shapes[0][1]
paddle_op = 'paddle.nn.Conv{}DTranspose'.format(convnd) paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1) num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd) strides = node.get_attr('strides', [1] * convnd)
...@@ -1622,37 +1619,21 @@ class OpSet9(): ...@@ -1622,37 +1619,21 @@ class OpSet9():
output_size[1] = (val_x.out_shapes[0][3] - 1 output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * ( ) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1] kernel_shape[1] - 1) + 1 + out_padding[1]
# Conv2DTranspose缺少output_size,只能在forward里头传进output_size
inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
"weight": val_w.name}
layer_attrs = { layer_attrs = {
'in_channels': num_in_channels, "stride": strides,
'out_channels': num_out_channels, "dilation": dilations,
'output_size': output_size or None, "padding": paddings,
'kernel_size': kernel_shape, "groups": num_groups,
'padding': paddings, "output_size": node.out_shapes[0][2:]}
'stride': strides, if val_b is not None:
'dilation': dilations, inputs_dict["bias"] = val_b.name
'groups': num_groups, else:
'weight_attr': string(val_w.name), layer_attrs["bias"] = None
'bias_attr': None if val_b is None else string(val_b.name),
}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
paddle_op, kernel="paddle.nn.functional.conv2d_transpose",
inputs={"x": val_x.name}, inputs=inputs_dict,
outputs=layer_outputs, outputs=[node.name],
**layer_attrs) **layer_attrs)
# inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
# "weight": val_w.name}
# layer_attrs = {
# "stride": strides,
# "dilation": dilations,
# "padding": paddings,
# "groups": num_groups,
# "output_size": node.out_shapes[0][2:]}
# if val_b is not None:
# inputs_dict["bias"] = val_b.name
# else:
# layer_attrs["bias"] = None
# self.paddle_graph.add_layer(
# kernel="paddle.nn.functional.conv2d_transpose",
# inputs=inputs_dict,
# outputs=[node.name],
# **layer_attrs)
from .opset import OpSet9 from .opset import OpSet9
from .custom_layer import custom_layers
...@@ -17,7 +17,6 @@ from x2paddle.core.graph import GraphNode ...@@ -17,7 +17,6 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string from x2paddle.core.util import string
from x2paddle.op_mapper.static.onnx2paddle.opset9.custom_layer import *
from functools import reduce from functools import reduce
import numpy as np import numpy as np
import onnx import onnx
...@@ -508,33 +507,30 @@ class OpSet9(): ...@@ -508,33 +507,30 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def InstanceNormalization(self, node): def InstanceNormalization(self, node):
op_name = name_generator("instanse_norm", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scale = self.graph.get_input_node(node, idx=1, copy=True) val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
epsilon = node.get_attr('epsilon', 1e-5) epsilon = node.get_attr('epsilon', 1e-5)
layer_attrs = { layer_attrs = {
'epsilon': epsilon, 'eps': epsilon,
} }
dim = len(val_x.out_shapes[0]) dim = len(val_x.out_shapes[0])
if dim ==2 : if dim ==2 :
layer_attrs["data_format"] = "NC" layer_attrs["data_format"] = string("NC")
elif dim == 3: elif dim == 3:
layer_attrs["data_format"] = "NCL" layer_attrs["data_format"] = string("NCL")
elif dim == 4: elif dim == 4:
layer_attrs["data_format"] = "NCHW" layer_attrs["data_format"] = string("NCHW")
elif dim == 5: elif dim == 5:
layer_attrs["data_format"] = "NCDHW" layer_attrs["data_format"] = string("NCDHW")
else: else:
raise Exception("The paddle only support 2D, 3D, 4D or 5D input in InstanceNormalization.") raise Exception("The paddle only support 2D, 3D, 4D or 5D input in InstanceNormalization.")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
paddle_op, "paddle.nn.functional.instance_norm",
inputs={"x": val_x.name, inputs={"x": val_x.name,
"weight": val_scale.name, "weight": val_scale.name,
"bias": val_b.name}, "bias": val_b.name},
outputs=layer_outputs, outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info @print_mapping_info
...@@ -1577,7 +1573,7 @@ class OpSet9(): ...@@ -1577,7 +1573,7 @@ class OpSet9():
if val_b is not None: if val_b is not None:
layer_inputs["bias"] = val_b.name layer_inputs["bias"] = val_b.name
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.conv2d_transpose", kernel=paddle_op,
inputs=layer_inputs, inputs=layer_inputs,
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册