diff --git a/docs/inference_model_convertor/demo/tensorflow2paddle.ipynb b/docs/inference_model_convertor/demo/tensorflow2paddle.ipynb
index e9d4d0797ff86b94cf86a60a4048a36776674e16..30122def364afddb72466a6f614f3bee8ffa075e 100644
--- a/docs/inference_model_convertor/demo/tensorflow2paddle.ipynb
+++ b/docs/inference_model_convertor/demo/tensorflow2paddle.ipynb
@@ -81,7 +81,7 @@
    "source": [
     "## Model migration\n",
     "### 1. Obtain the FrozenModel of MobileNetV1\n",
-    "Since X2Paddle only supports converting TensorFlow FrozenModels, a pure checkpoint model must first be converted to a FrozenModel by following the official X2Paddle [documentation](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/user_guides/export_tf_model.md). The model provided in this example is already a FrozenModel, so no conversion is needed."
+    "Since X2Paddle only supports converting TensorFlow FrozenModels, a pure checkpoint model must first be converted to a FrozenModel by following the official X2Paddle [documentation](https://github.com/PaddlePaddle/X2Paddle/blob/release-1.1/docs/user_guides/export_tf_model.md). The model provided in this example is already a FrozenModel, so no conversion is needed."
   ]
  },
 {
@@ -210,4 +210,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/docs/pytorch_project_convertor/API_docs/ops/README.md b/docs/pytorch_project_convertor/API_docs/ops/README.md
index 086c1f41a439129da5d3a1b5e95fc447ac1816a6..00fc26dde7ee78ec3dd5674fbd5cc34a4a0f419f 100644
--- a/docs/pytorch_project_convertor/API_docs/ops/README.md
+++ b/docs/pytorch_project_convertor/API_docs/ops/README.md
@@ -152,7 +152,7 @@
 | 147 | [torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=matmul#torch.matmul) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [Difference comparison](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.matmul.md) |
 | 148 | [torch.mm](https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=mm#torch.mm) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [Difference comparison](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mm.md) |
 | 149 | [torch.mv](https://pytorch.org/docs/stable/generated/torch.mv.html?highlight=mv#torch.mv) | No direct equivalent | [Composite implementation](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mv.md) |
-
+| 150 | [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter) | [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html) | [Composite implementation](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md) |
diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md
new file mode 100644
index 0000000000000000000000000000000000000000..2eb4bb9a854109d6ee5d2770afe552abe8f206f3
--- /dev/null
+++ b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md
@@ -0,0 +1,75 @@
+## torch.scatter
+### [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter)
+
+```python
+torch.scatter(tensor,
+              dim,
+              index,
+              src)
+```
+
+### [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html)
+
+```python
+paddle.scatter_nd_add(x,
+                      index,
+                      updates,
+                      name=None)
+```
+
+### Parameter differences
+| PyTorch | PaddlePaddle | Notes |
+| ------------- | ------------ | ------------------------------------------------------ |
+| tensor | x | The input tensor. |
+| dim | - | The dimension along which to scatter; Paddle has no such parameter. |
+| index | index | The index tensor. |
+| src | updates | The tensor holding the update values. |
+
+
+### Functional differences
+
+#### Usage
+Because torch.scatter and paddle.scatter_nd_add differ substantially, the op must be emulated by combining paddle.flatten, paddle.meshgrid, paddle.stack, paddle.gather_nd, and paddle.scatter_nd_add, as in the following example.
+
+### Code examples
+```python
+# PyTorch example:
+src = torch.arange(1, 11).reshape((2, 5))
+# Output
+# tensor([[ 1,  2,  3,  4,  5],
+#         [ 6,  7,  8,  9, 10]])
+index = torch.tensor([[0, 1, 2], [0, 1, 4]])
+torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
+# Output
+# tensor([[1, 2, 3, 0, 0],
+#         [6, 7, 0, 0, 8],
+#         [0, 0, 0, 0, 0]])
+```
+
+```python
+# PaddlePaddle composite implementation:
+x = paddle.zeros([3, 5], dtype="int64")
+updates = paddle.arange(1, 11).reshape([2, 5])
+# Output
+# Tensor(shape=[2, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+#        [[1 , 2 , 3 , 4 , 5 ],
+#         [6 , 7 , 8 , 9 , 10]])
+index = paddle.to_tensor([[0, 1, 2], [0, 1, 4]])
+i, j = index.shape
+grid_x, grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j))
+# If the PyTorch dim is 0:
+# index = paddle.stack([index.flatten(), grid_y.flatten()], axis=1)
+# If the PyTorch dim is 1:
+index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1)
+# The shape of the PaddlePaddle updates must match index
+updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
+updates = paddle.gather_nd(updates, index=updates_index)
+paddle.scatter_nd_add(x, index, updates)
+# Output
+# Tensor(shape=[3, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+#        [[1, 2, 3, 0, 0],
+#         [6, 7, 0, 0, 8],
+#         [0, 0, 0, 0, 0]])
+```
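A caveat worth keeping in mind next to the recipe above (the new doc does not call it out): `torch.scatter` *assigns* `src` values into the target, while `paddle.scatter_nd_add` *accumulates* them, so the composite implementation is only equivalent when the scattered positions of `x` start out as zero. A minimal sketch of the pitfall, with illustrative values:

```python
import paddle

# scatter_nd_add accumulates into x rather than overwriting it, so a
# non-zero target gives a different result from torch.scatter.
x = paddle.ones([3, 5], dtype="int64")
index = paddle.to_tensor([[0, 0], [0, 1]])  # scatter into x[0, 0] and x[0, 1]
updates = paddle.to_tensor([10, 20], dtype="int64")
out = paddle.scatter_nd_add(x, index, updates)
# out[0, 0] == 11 and out[0, 1] == 21 (existing 1 + update), whereas
# torch.scatter would have produced 10 and 20 (pure assignment).
```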
diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py
old mode 100644
new mode 100755
index 1841c7ce97b553bfd88343c86c1f5a77ff4595c1..f04748f296ccb539133194ff69287da3cdf4ab46
--- a/x2paddle/core/program.py
+++ b/x2paddle/core/program.py
@@ -27,22 +27,23 @@ from x2paddle.core.util import *
 class PaddleLayer(object):
     def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs):
-        assert isinstance(
-            inputs,
-            dict), "parameter 'inputs' for PaddleLayer should be type of dict"
+        assert isinstance(inputs, (
+            dict, list
+        )), "parameter 'inputs' for PaddleLayer should be type of dict or list"
         assert isinstance(
             outputs,
             list), "parameter 'outputs' for PaddleLayer should be type of list"
-        for k, v in inputs.items():
-            if isinstance(v, (list, tuple)):
-                for i in v:
-                    assert isinstance(
-                        i, six.string_types
+        if isinstance(inputs, dict):
+            for k, v in inputs.items():
+                if isinstance(v, (list, tuple)):
+                    for i in v:
+                        assert isinstance(
+                            i, six.string_types
+                        ), "value in inputs should be type of string or list of string"
+                else:
+                    assert isinstance(v, six.string_types) or isinstance(
+                        v, list
                     ), "value in inputs should be type of string or list of string"
-            else:
-                assert isinstance(v, six.string_types) or isinstance(
-                    v, list
-                ), "value in inputs should be type of string or list of string"
         for v in outputs:
             assert isinstance(
                 v, six.
                 string_types), "parameter 'outputs' for PaddleLayer should be type of string"
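For context on the `dict`-or-`list` relaxation above: list-style `inputs` carry positional tensor names for ops such as `paddle.meshgrid` (used by the new `GatherElements` mapping later in this patch), while dict-style inputs remain keyword arguments. A minimal sketch of how the two forms render in generated code; `render_call` is a hypothetical stand-in that only mirrors the branching added to the code generator:

```python
def render_call(kernel, inputs, attrs):
    # Mirrors the new branching in PaddleGraph.gen_code: dict inputs become
    # keyword arguments, list inputs become positional arguments.
    if isinstance(inputs, dict):
        args = ", ".join("{}={}".format(k, v) for k, v in inputs.items())
    else:  # plain list of tensor names, e.g. for paddle.meshgrid
        args = ", ".join(inputs)
    kwargs = ", ".join("{}={}".format(k, v) for k, v in attrs.items())
    return "{}({})".format(kernel, ", ".join(p for p in (args, kwargs) if p))

print(render_call("paddle.meshgrid", ["x0", "x1"], {}))
# -> paddle.meshgrid(x0, x1)
print(render_call("paddle.concat", {"x": "[a, b]"}, {"axis": -1}))
# -> paddle.concat(x=[a, b], axis=-1)
```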
@@ -164,11 +165,31 @@ class PaddleGraph(object):
         self.clear_edges()
         outputs_from_nodes = dict()
         for layer_id, layer in self.layers.items():
-            for input_key, input_var in layer.inputs.items():
-                vs = input_var
-                if not isinstance(vs, (list, tuple)):
-                    vs = [vs]
-                for v in vs:
+            if isinstance(layer.inputs, dict):
+                for input_key, input_var in layer.inputs.items():
+                    vs = input_var
+                    if not isinstance(vs, (list, tuple)):
+                        vs = [vs]
+                    for v in vs:
+                        assert v in outputs_from_nodes or (
+                            inputs is not None and v in list(inputs.values())
+                        ) or (
+                            outputs is not None and v in outputs
+                        ), "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
+                            v)
+                        if v in outputs_from_nodes:
+                            in_layer_id = outputs_from_nodes[v]
+                        else:
+                            in_layer_id = -1
+                        if in_layer_id not in self.edges_out:
+                            self.edges_out[in_layer_id] = list()
+                        self.edges_out[in_layer_id].append(layer_id)
+
+                        if layer_id not in self.edges_in:
+                            self.edges_in[layer_id] = list()
+                        self.edges_in[layer_id].append(in_layer_id)
+            else:
+                for v in layer.inputs:
                     assert v in outputs_from_nodes or (
                         inputs is not None and v in list(inputs.values())
                     ) or (
                         outputs is not None and v in outputs
                     ), "Couldn't find {} in previous layers, the layers should be make by topological sort".format(
                         v)
                     if v in outputs_from_nodes:
                         in_layer_id = outputs_from_nodes[v]
                     else:
                         in_layer_id = -1
                     if in_layer_id not in self.edges_out:
                         self.edges_out[in_layer_id] = list()
                     self.edges_out[in_layer_id].append(layer_id)
@@ -186,6 +207,7 @@ class PaddleGraph(object):
                     if layer_id not in self.edges_in:
                         self.edges_in[layer_id] = list()
                     self.edges_in[layer_id].append(in_layer_id)
+
             for output in layer.outputs:
                 outputs_from_nodes[output] = layer_id
@@ -496,16 +518,20 @@ class PaddleGraph(object):
             else:
                 line = ','.join(layer.outputs)
             line += " = {}(".format(layer.kernel)
-            for k, v in layer.inputs.items():
-                if isinstance(v, list):
-                    line += "{}=[{}], ".format(k, ", ".join(v))
-                elif isinstance(v, tuple):
-                    line += "{}=({}), ".format(k, ", ".join(v))
-                else:
-                    if k == "args":
-                        line += v
+            if isinstance(layer.inputs, dict):
+                for k, v in layer.inputs.items():
+                    if isinstance(v, list):
+                        line += "{}=[{}], ".format(k, ", ".join(v))
+                    elif isinstance(v, tuple):
+                        line += "{}=({}), ".format(k, ", ".join(v))
                     else:
-                        line += "{}={}, ".format(k, v)
+                        if k == "args":
+                            line += v
+                        else:
+                            line += "{}={}, ".format(k, v)
+            else:
+                line += "{}".format(", ".join(layer.inputs))
+
             for k, v in layer.attrs.items():
                 line += "{}={}, ".format(k, v)
             line = line.strip(", ")
@@ -532,9 +558,9 @@ class PaddleGraph(object):
             paddle.save(self.parameters, save_path)

     def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
-        sepc_list = list()
+        spec_list = list()
         for i, name in enumerate(self.inputs):
-            sepc_list.append(
+            spec_list.append(
                 paddle.static.InputSpec(
                     shape=input_shapes[i], name=name, dtype=input_types[i]))
         path = osp.abspath(save_dir)
@@ -548,7 +574,7 @@ class PaddleGraph(object):
         else:
             model.set_dict(restore)
         model.eval()
-        static_model = paddle.jit.to_static(model, input_spec=sepc_list)
+        static_model = paddle.jit.to_static(model, input_spec=spec_list)
         try:
             paddle.jit.save(static_model,
                             osp.join(save_dir, "inference_model/model"))
diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py
index db361db2838ea79e9ad84d345a9260b6cac36503..db4ed90ba33d1aa819c0973e76a3df15f13688dc 100755
--- a/x2paddle/decoder/onnx_decoder.py
+++ b/x2paddle/decoder/onnx_decoder.py
@@ -583,6 +583,9 @@ class ONNXDecoder(object):
             item.name = self.make_variable_name(item.name)
         for node in graph.node:
             node.name = node.output[0]
+            if ":" in node.name and len(
+                    node.output) > 1 and node.op_type != "LSTM":
+                node.name = node.name.split(':')[0]
             node.name = self.make_variable_name(node.name)
             for i in range(len(node.input)):
                 if node.input[i] == '':
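The `sepc_list` → `spec_list` rename is cosmetic, but for reviewers unfamiliar with the export path, this is the standard Paddle dynamic-to-static flow the method drives. A self-contained sketch with a made-up model and input shape:

```python
import paddle

# Hypothetical model standing in for the converted network.
model = paddle.nn.Sequential(paddle.nn.Linear(8, 4))
model.eval()

# One InputSpec per graph input, as dygraph2static builds from
# input_shapes/input_types.
spec_list = [
    paddle.static.InputSpec(shape=[None, 8], name="x", dtype="float32")
]

static_model = paddle.jit.to_static(model, input_spec=spec_list)
paddle.jit.save(static_model, "inference_model/model")
```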
diff --git a/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py b/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py
index e05f60c317496fd49f798ec3acbf60f5e0d592ef..c120945b1f6c1ac0c8c38a65091bf085b66031e5 100644
--- a/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py
+++ b/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py
@@ -966,11 +966,12 @@ class CaffeOpMapper():
             inputs={"x": input.name},
             outputs=[node.layer_name],
             **layer_attrs)
-        self.paddle_graph.add_layer(
-            "paddle.pow",
-            inputs={"x": node.layer_name},
-            outputs=[node.layer_name],
-            exponent=params.power)
+        if params.power != 1:
+            self.paddle_graph.add_layer(
+                "paddle.pow",
+                inputs={"x": node.layer_name,
+                        "y": params.power},
+                outputs=[node.layer_name])

     def Reduction(self, node):
         assert len(
diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
index 55ec32ee3fa841c4ee8e152ed284052f33bdbb41..6fb18a03ecd1f224d6ac86f2036c3900f20365c2 100755
--- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
@@ -62,6 +62,7 @@ def _rename_or_remove_weight(weights,
     if origin_name not in weights:
         raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
     if is_remove:
+        # TODO There may be problems when the same data is used as an argument to multiple OPs.
         # remove weight
         data = weights.pop(origin_name)
     else:
@@ -169,6 +170,8 @@ class OpSet9():
         'Floor': ['paddle.floor'],
         'Abs': ['paddle.abs'],
         'Erf': ['paddle.erf'],
+        'Sin': ['paddle.sin'],
+        'Cos': ['paddle.cos'],
     }

     def __init__(self, decoder, paddle_graph):
@@ -248,6 +251,7 @@ class OpSet9():
                 node = parameter
             dtype = node.dtype
             shape = node.out_shapes[0]
+
             if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
                 self.paddle_graph.add_layer(
                     "paddle.full",
@@ -302,6 +306,7 @@ class OpSet9():
             elif len(node.layer.input) == 4:
                 # opset 11
                 val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
+                size_values = _const_weight_or_none(val_sizes)
                 val_x_shape = val_x.out_shapes[0]
                 if len(val_x_shape) == 3:
                     var_n, var_hw = val_sizes.name + '_n', val_sizes.name + '_hw'
@@ -347,23 +352,26 @@ class OpSet9():
                         outputs=[node.name],
                         axis=0)
                 else:
-                    var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
-                    self.paddle_graph.add_layer(
-                        'paddle.split',
-                        inputs={"x": val_sizes.name},
-                        outputs=[var_nc, var_hw],
-                        num_or_sections=[2, 2],
-                        axis=0)
-                    self.paddle_graph.add_layer(
-                        "paddle.cast",
-                        inputs={"x": var_hw},
-                        outputs=[var_hw],
-                        dtype=string('int32'))
-                    inputs['size'] = var_hw
-                    attrs = {
+                    if size_values is not None:
+                        attrs["size"] = [size_values[2], size_values[3]]
+                    else:
+                        var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
+                        self.paddle_graph.add_layer(
+                            'paddle.split',
+                            inputs={"x": val_sizes.name},
+                            outputs=[var_nc, var_hw],
+                            num_or_sections=[2, 2],
+                            axis=0)
+                        self.paddle_graph.add_layer(
+                            "paddle.cast",
+                            inputs={"x": var_hw},
+                            outputs=[var_hw],
+                            dtype=string('int32'))
+                        inputs['size'] = var_hw
+                    attrs.update({
                         "align_corners": False,
                         "mode": string(node.get_attr('mode', 'nearest'))
-                    }
+                    })
                 mode = node.get_attr('mode', 'nearest')
                 if mode == "linear":
                     attrs["mode"] = string("bilinear")
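For the Resize change above: when the ONNX `sizes` input is a compile-time constant, `_const_weight_or_none` lets the converter emit a literal `size` attribute instead of runtime `split`/`cast` ops. A small sketch of the two resulting call styles, with illustrative shapes:

```python
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 32, 32])

# Constant `sizes` folded at conversion time -> plain Python ints in `size`.
y1 = F.interpolate(x, size=[64, 64], mode="nearest")

# Non-constant `sizes` -> the converter still splits the runtime tensor into
# NC and HW halves and casts HW to int32 before passing it as `size`.
sizes = paddle.to_tensor([1, 3, 64, 64], dtype="int32")
y2 = F.interpolate(x, size=sizes[2:], mode="nearest")

print(y1.shape, y2.shape)  # [1, 3, 64, 64] for both
```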
@@ -381,15 +389,18 @@ class OpSet9():
                 **attrs)
             return
         elif node.layer_type == 'Upsample':
-            val_scales = self.graph.get_input_node(node, idx=1, copy=True)
-            self.paddle_graph.add_layer(
-                "paddle.slice",
-                inputs={"input": val_scales.name},
-                outputs=[val_scales.name],
-                axes=[0],
-                starts=[2],
-                ends=[4])
-            inputs['scale_factor'] = val_scales.name
+            if len(node.layer.input) == 2:
+                val_scales = self.graph.get_input_node(node, idx=1, copy=True)
+                self.paddle_graph.add_layer(
+                    "paddle.slice",
+                    inputs={"input": val_scales.name},
+                    outputs=[val_scales.name],
+                    axes=[0],
+                    starts=[2],
+                    ends=[4])
+                inputs['scale_factor'] = val_scales.name
+            else:
+                val_scales = node.get_attr('scales')[2:]

         mode = node.get_attr('mode', 'nearest')
         attrs.update({
@@ -397,6 +408,8 @@ class OpSet9():
             "mode": string(mode),
             "align_mode": 1
         })
+        if len(node.layer.input) == 1:
+            attrs["scale_factor"] = val_scales
         val_x_shape = val_x.out_shapes[0]
         if mode == "linear" and len(val_x_shape) == 4:
             attrs["mode"] = string("bilinear")
@@ -676,8 +689,7 @@ class OpSet9():
         axes = node.get_attr('axes')
         if axes is None:
             axes = self.graph.get_input_node(node, idx=1, copy=True)
-
-        if len(val_x.out_shapes[0]) == 0:
+        if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
             if node.name:
                 self.paddle_graph.add_layer(
                     'paddle.reshape',
@@ -798,11 +810,19 @@ class OpSet9():
         val_shape = self.graph.get_input_node(node, idx=1, copy=True)
         val_x_dtype = val_x.dtype
         name_ones = node.name + '_ones'
-        attr_ones = {
-            'shape': val_shape.name,
-            'dtype': string(val_x_dtype),
-            'fill_value': 1
-        }
+        shape_values = _const_weight_or_none(val_shape)
+        if shape_values is None:
+            attr_ones = {
+                'shape': val_shape.name,
+                'dtype': string(val_x_dtype),
+                'fill_value': 1
+            }
+        else:
+            attr_ones = {
+                'shape': shape_values.tolist(),
+                'dtype': string(val_x_dtype),
+                'fill_value': 1
+            }
         self.paddle_graph.add_layer(
             'paddle.full', inputs={}, outputs=[name_ones], **attr_ones)
         inputs_dict = {'x': name_ones, 'y': val_x.name}
@@ -834,6 +854,11 @@ class OpSet9():
                     outputs=[node.name])
             elif len(val_x.out_shapes[0]) > 1:
                 if len(indices_shape) == 0:
+                    self.paddle_graph.add_layer(
+                        'paddle.reshape',
+                        inputs={"x": indices.name},
+                        outputs=[indices.name],
+                        shape=[-1, ])
                     gather_ = node.name + '_1'
                     self.paddle_graph.add_layer(
                         'paddle.gather',
@@ -1136,6 +1161,10 @@ class OpSet9():
         starts = node.get_attr('starts')
         ends = node.get_attr('ends')
         axes = node.get_attr('axes')
+        output_shape = val_x.out_shapes[0]
+
+        if axes is None:
+            axes = [i for i in range(len(starts))]
         for idx in range(len(ends)):
             if ends[idx] > 2**31 - 1:
                 ends[idx] = 2**31 - 1
@@ -1176,7 +1205,6 @@ class OpSet9():

     @print_mapping_info
     def GatherND(self, node):
-        print(len(node.inputs), node.inputs)
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_input_node(node, idx=1, copy=True)
         self.paddle_graph.add_layer(
@@ -1342,7 +1370,6 @@ class OpSet9():

     @print_mapping_info
     def GatherND(self, node):
-        print(len(node.inputs), node.inputs)
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_input_node(node, idx=1, copy=True)
         self.paddle_graph.add_layer(
@@ -1366,8 +1393,6 @@ class OpSet9():
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         paddle_op = 'split'
         split = node.get_attr('split')
-        if split is None:
-            split = len(node.outputs)
         axis = node.get_attr('axis', 0)
         if split is None:
             split_num = len(node.layer.output)
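On the new `axes is None` default in `Slice`: per the ONNX spec, omitted `axes` default to `[0, 1, ..., len(starts) - 1]`. A small numpy sketch of the rule:

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
starts, ends = [1, 0], [3, 2]

# ONNX Slice with axes omitted: axes default to [0, 1, ..., len(starts) - 1].
axes = list(range(len(starts)))

out = x[starts[0]:ends[0], starts[1]:ends[1]]  # equivalent slice on axes (0, 1)
print(out)
# [[4 5]
#  [8 9]]
```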
@@ -1972,6 +1997,143 @@ class OpSet9():
             outputs=layer_outputs,
             output_size=output_shape[2:])

+    @print_mapping_info
+    def Neg(self, node):
+        import paddle
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        v0, v1, v2 = paddle.__version__.split('.')
+        if int(v0) >= 2 and int(v1) >= 2:
+            self.paddle_graph.add_layer(
+                "paddle.neg", inputs={'x': val_x.name}, outputs=[node.name])
+        else:
+            val_y = node.name + "_y"
+            dtype = np.dtype(val_x.dtype)
+            self.paddle_graph.add_layer(
+                "paddle.full",
+                inputs={},
+                outputs=[val_y],
+                dtype=string(dtype),
+                shape=[1],
+                fill_value=-1)
+            self.paddle_graph.add_layer(
+                "paddle.multiply",
+                inputs={'x': val_x.name,
+                        'y': val_y},
+                outputs=[node.name])
+
+    @print_mapping_info
+    def SpaceToDepth(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        blocksize = node.get_attr('blocksize')
+        val_x_shape = val_x.out_shapes[0]
+        b, c, h, w = val_x_shape
+        self.paddle_graph.add_layer(
+            'paddle.reshape',
+            inputs={"x": val_x.name},
+            outputs=[node.name],
+            shape=[b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+        self.paddle_graph.add_layer(
+            'paddle.transpose',
+            inputs={"x": node.name},
+            outputs=[node.name],
+            perm=[0, 3, 5, 1, 2, 4])
+        self.paddle_graph.add_layer(
+            'paddle.reshape',
+            inputs={"x": node.name},
+            outputs=[node.name],
+            shape=[b, c * (blocksize**2), h // blocksize, w // blocksize])
+
+    @print_mapping_info
+    def GatherElements(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        indices = self.graph.get_input_node(node, idx=1, copy=True)
+        axis = node.get_attr('axis')
+        val_x_shape = val_x.out_shapes[0]
+        indices_shape = indices.out_shapes[0]
+        axis = axis if axis >= 0 else axis + len(val_x_shape)
+        if axis == 0:
+            axis_perm = [i for i in range(len(val_x_shape))]
+            data_swaped = val_x.name
+            index_swaped = indices.name
+        else:
+            axis_perm = [i for i in range(len(val_x_shape))]
+            axis_perm[axis] = 0
+            axis_perm[0] = axis
+            data_swaped = val_x.name + "_transpose"
+            self.paddle_graph.add_layer(
+                "paddle.transpose",
+                inputs={'x': val_x.name},
+                perm=axis_perm,
+                outputs=[data_swaped])
+            index_swaped = indices.name + "_transpose"
+            self.paddle_graph.add_layer(
+                "paddle.transpose",
+                inputs={'x': indices.name},
+                perm=axis_perm,
+                outputs=[index_swaped])
+            temp = indices_shape[0]
+            indices_shape[0] = indices_shape[axis]
+            indices_shape[axis] = temp
+
+        idx_tensors_per_axis_pre = [
+            indices_shape[i] for i in range(len(indices_shape))
+        ]
+        name_list = list()
+        for i in range(len(idx_tensors_per_axis_pre)):
+            tensor_name = val_x.name + "_meshgrid_" + str(i)
+            self.paddle_graph.add_layer(
+                kernel="paddle.linspace",
+                inputs={},
+                outputs=[tensor_name],
+                start=0,
+                stop=idx_tensors_per_axis_pre[i] - 1,
+                num=idx_tensors_per_axis_pre[i])
+            name_list.append(tensor_name)
+
+        self.paddle_graph.add_layer(
+            "paddle.meshgrid", inputs=name_list, outputs=name_list)
+
+        self.paddle_graph.add_layer(
+            "paddle.cast",
+            inputs={"x": index_swaped},
+            outputs=[index_swaped],
+            dtype=string("float32"))
+        import copy
+        copy_name_list = copy.copy(name_list)
+        copy_name_list[0] = index_swaped
+        new_name_list = list()
+        for i in range(len(copy_name_list)):
+            unsqueeze_name = copy_name_list[i] + "_unsqueeze"
+            self.paddle_graph.add_layer(
+                "paddle.unsqueeze",
+                inputs={"x": copy_name_list[i]},
+                axis=-1,
+                outputs=[unsqueeze_name])
+            new_name_list.append(unsqueeze_name)
+        concat_name = val_x.name + "_concated_layer"
+        self.paddle_graph.add_layer(
+            "paddle.concat",
+            inputs={'x': new_name_list},
+            axis=-1,
+            outputs=[concat_name])
+        self.paddle_graph.add_layer(
+            "paddle.cast",
+            inputs={"x": concat_name},
+            outputs=[concat_name],
+            dtype=string("int32"))
+        gather_nd_name = "gather_nd_layer"
+        self.paddle_graph.add_layer(
+            "paddle.gather_nd",
+            inputs={'x': data_swaped,
+                    "index": concat_name},
+            outputs=[gather_nd_name])
+
+        self.paddle_graph.add_layer(
+            "paddle.transpose",
+            inputs={'x': gather_nd_name},
+            perm=axis_perm,
+            outputs=[node.name])
+
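The `SpaceToDepth` mapping above uses the standard reshape → transpose → reshape decomposition from the ONNX reference implementation. A numpy sketch verifying the layout for an NCHW input:

```python
import numpy as np

def space_to_depth(x, bs):
    # reshape -> transpose -> reshape, mirroring the decomposition above
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // bs, bs, w // bs, bs)
    x = x.transpose(0, 3, 5, 1, 2, 4)
    return x.reshape(b, c * bs * bs, h // bs, w // bs)

x = np.arange(16).reshape(1, 1, 4, 4)
print(space_to_depth(x, 2)[0, :, 0, 0])  # the first 2x2 block, flattened into channels
# [0 1 4 5]
```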
     @print_mapping_info
     def GlobalAveragePool(self, node):
         op_name = name_generator("pool", self.nn_name2id)
@@ -2126,14 +2288,35 @@ class OpSet9():

         paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)

-        output_size = [0, 0]
+        if len(output_size) != 0:
+            paddings = [0] * 4
+            total_paddings = list()
+            total_paddings.append((val_x.out_shapes[0][2] - 1) * strides[
+                0] + dilations[0] * (kernel_shape[0] - 1) + 1 + out_padding[0] -
+                                  output_size[0])
+            total_paddings.append((val_x.out_shapes[0][3] - 1) * strides[
+                1] + dilations[1] * (kernel_shape[1] - 1) + 1 + out_padding[1] -
+                                  output_size[1])
+            if auto_pad == "SAME_UPPER":
+                for i in range(len(total_paddings)):
+                    paddings[2 * i] = total_paddings[i] - total_paddings[i] // 2
+                    paddings[2 * i + 1] = total_paddings[i] // 2
+            else:
+                for i in range(len(total_paddings)):
+                    paddings[2 * i] = total_paddings[i] // 2
+                    paddings[2 * i + 1] = total_paddings[i] - total_paddings[
+                        i] // 2
+        else:
+            output_size = [0, 0]

-        output_size[0] = (val_x.out_shapes[0][2] - 1
-                          ) * strides[0] - 2 * paddings[0] + dilations[0] * (
-                              kernel_shape[0] - 1) + 1 + out_padding[0]
-        output_size[1] = (val_x.out_shapes[0][3] - 1
-                          ) * strides[1] - 2 * paddings[1] + dilations[1] * (
-                              kernel_shape[1] - 1) + 1 + out_padding[1]
+            output_size[0] = (
+                val_x.out_shapes[0][2] - 1
+            ) * strides[0] - 2 * paddings[0] + dilations[0] * (
+                kernel_shape[0] - 1) + 1 + out_padding[0]
+            output_size[1] = (
+                val_x.out_shapes[0][3] - 1
+            ) * strides[1] - 2 * paddings[1] + dilations[1] * (
+                kernel_shape[1] - 1) + 1 + out_padding[1]
         # Conv2DTranspose lacks output_size; it can only be passed in at forward time
         inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name}
@@ -2176,6 +2359,8 @@ class OpSet9():
         if val_b is not None:
             _rename_or_remove_weight(self.weights, val_b.name,
                                      op_name + '.bias')
+        else:
+            layer_attrs["bias_attr"] = False
         self.paddle_graph.add_layer(
             kernel=paddle_op,
             inputs=inputs_dict,
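The new `output_size` branch inverts the ConvTranspose shape identity `out = (in - 1) * stride - 2 * pad + dilation * (k - 1) + 1 + out_pad` to recover the total padding per spatial axis. A worked example with made-up numbers:

```python
# Worked example of the ConvTranspose shape identity used above.
in_size, stride, dilation, k, out_pad = 8, 2, 1, 3, 0

# Forward direction with no padding -> the largest possible output.
out_no_pad = (in_size - 1) * stride + dilation * (k - 1) + 1 + out_pad  # 17

# If the ONNX model requests output_size = 15, the total padding to trim is:
requested = 15
total_padding = out_no_pad - requested  # 2
# ...which is then split into begin/end halves (SAME_UPPER puts the larger
# half first; otherwise the larger half goes last).
pad_begin, pad_end = total_padding // 2, total_padding - total_padding // 2
assert (in_size - 1) * stride - (pad_begin + pad_end) \
    + dilation * (k - 1) + 1 + out_pad == requested
```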
diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py
index 683dbad39591a8dc1f3658e46be730b3f0da2b15..b11a41440cf30af8869868af1ff121b3a2166560 100755
--- a/x2paddle/op_mapper/pytorch2paddle/aten.py
+++ b/x2paddle/op_mapper/pytorch2paddle/aten.py
@@ -1315,8 +1315,10 @@ def aten__convolution(mapper, graph, node):
     weights = mapper.pytorch_params[inputs_name[1]]
     if len(weights.shape) == 3:
         op_name = name_generator("conv1d", mapper.nn_name2id)
-    else:
+    elif len(weights.shape) == 4:
         op_name = name_generator("conv2d", mapper.nn_name2id)
+    else:
+        op_name = name_generator("conv3d", mapper.nn_name2id)
     output_name = mapper._get_outputs_name(node)[0]
     layer_outputs = [op_name, output_name]
     layer_inputs = {}
@@ -1364,7 +1366,22 @@ def aten__convolution(mapper, graph, node):
     else:
         layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[
             inputs_name[8]]
-    if len(weights.shape) == 4:
+    if len(weights.shape) == 3:
+        if mapper.attrs[inputs_name[6]]:
+            graph.add_layer(
+                "paddle.nn.Conv1DTranspose",
+                inputs=layer_inputs,
+                outputs=layer_outputs,
+                scope_name=scope_name,
+                **layer_attrs)
+        else:
+            graph.add_layer(
+                "paddle.nn.Conv1D",
+                inputs=layer_inputs,
+                outputs=layer_outputs,
+                scope_name=scope_name,
+                **layer_attrs)
+    elif len(weights.shape) == 4:
         if mapper.attrs[inputs_name[6]]:
             graph.add_layer(
                 "paddle.nn.Conv2DTranspose",
                 inputs=layer_inputs,
                 outputs=layer_outputs,
@@ -1382,14 +1399,14 @@ def aten__convolution(mapper, graph, node):
     else:
         if mapper.attrs[inputs_name[6]]:
             graph.add_layer(
-                "paddle.nn.Conv1DTranspose",
+                "paddle.nn.Conv3DTranspose",
                 inputs=layer_inputs,
                 outputs=layer_outputs,
                 scope_name=scope_name,
                 **layer_attrs)
         else:
             graph.add_layer(
-                "paddle.nn.Conv1D",
+                "paddle.nn.Conv3D",
                 inputs=layer_inputs,
                 outputs=layer_outputs,
                 scope_name=scope_name,
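The rank-based dispatch above generalizes naturally: `aten::_convolution` weights are laid out as `[out_channels, in_channels, *kernel]`, so ranks 3, 4, and 5 select 1-D, 2-D, and 3-D convolutions. A condensed sketch of the mapping (`pick_paddle_conv` is a hypothetical helper; the `transposed` flag mirrors `mapper.attrs[inputs_name[6]]`):

```python
def pick_paddle_conv(weight_rank, transposed):
    # Weight ranks 3/4/5 correspond to 1-D/2-D/3-D convolutions.
    table = {
        3: ("paddle.nn.Conv1DTranspose", "paddle.nn.Conv1D"),
        4: ("paddle.nn.Conv2DTranspose", "paddle.nn.Conv2D"),
        5: ("paddle.nn.Conv3DTranspose", "paddle.nn.Conv3D"),
    }
    return table[weight_rank][0 if transposed else 1]

print(pick_paddle_conv(5, False))  # paddle.nn.Conv3D -- the case this patch adds
```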