diff --git a/docs/inference_model_convertor/op_list.md b/docs/inference_model_convertor/op_list.md
index 5e67412a9f2eef13e68f72792b5c6dd25993576d..2cf8ca6186905d6704400dda9d1b4e4e565edb03 100755
--- a/docs/inference_model_convertor/op_list.md
+++ b/docs/inference_model_convertor/op_list.md
@@ -114,7 +114,7 @@ Aten:
 | 117 | aten::bitwise\_not | 118 | aten::bitwise\_xor | 119 | aten::bitwise\_and | 120 | aten::silu |
 | 121 | aten::repeat\_interleave | 122 | aten::maxpool1d | 123 | aten::frobenius\_norm | 124 | aten::format |
 | 125 | aten::complex | 126 | aten::real | 127 | aten::imag | 128 | aten::fft\_rfftn |
-| 129 | aten::fft\_irfftn | 130 | aten::hardsigmoid | 131 | aten::hardswish | | |
+| 129 | aten::fft\_irfftn | 130 | aten::hardsigmoid | 131 | aten::hardswish | 132 | aten::linear |
 
 Prim:
 
diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py
index b0b3bf1783fcb288249ecca40d0ecff7d16394ba..11fd90e2ac117fc2be61574c3677ac3833ec4614 100755
--- a/x2paddle/op_mapper/pytorch2paddle/aten.py
+++ b/x2paddle/op_mapper/pytorch2paddle/aten.py
@@ -3238,6 +3238,55 @@ def aten_len(mapper, graph, node):
     return current_inputs, current_outputs
 
 
+def aten_linear(mapper, graph, node):
+    """
+    TorchScript Code:
+        %x.6 : Float(1, 128, strides=[128, 1]) = aten::linear(%input.305, %weight.629, %bias.317)
+    Parameter meaning:
+        %x.6 (Tensor): output
+        %input.305 (Tensor): input tensor
+        %weight.629 (Tensor): weight tensor
+        %bias.317 (Tensor): bias tensor
+    """
+    scope_name = mapper.normalize_scope_name(node)
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [output_name]
+    layer_inputs = {}
+    layer_attrs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # outputs list
+    current_outputs = [output_name]
+    # inputs list
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
+    layer_inputs["x"] = inputs_name[0]
+    # transpose weight
+    mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
+                        scope_name)
+    layer_attrs_transpose = {}
+    layer_attrs_transpose["perm"] = [1, 0]
+    graph.add_layer(
+        "paddle.transpose",
+        inputs={"x": inputs_name[1]},
+        outputs=[inputs_name[1] + "_transpose"],
+        scope_name=scope_name,
+        **layer_attrs_transpose)
+    layer_inputs["weight"] = inputs_name[1] + "_transpose"
+    if len(inputs_name) == 3:
+        mapper._check_input(graph, inputs_node[2], inputs_name[2],
+                            current_outputs, scope_name)
+        layer_inputs["bias"] = inputs_name[2]
+    current_inputs = list(layer_inputs.values())
+
+    graph.add_layer(
+        "paddle.nn.functional.linear",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
+    return current_inputs, current_outputs
+
+
 def aten_log(mapper, graph, node):
     """ 构造log的PaddleLayer。
     TorchScript示例:
diff --git a/x2paddle/optimizer/fusion/trace_fc_fuser.py b/x2paddle/optimizer/fusion/trace_fc_fuser.py
index dc752ec0a96028a95d0f69486a77fb3c75554bae..f3cbcd5d3af5f53f9f3ce2de362db1ed96ca7c39 100644
--- a/x2paddle/optimizer/fusion/trace_fc_fuser.py
+++ b/x2paddle/optimizer/fusion/trace_fc_fuser.py
@@ -113,10 +113,13 @@ class TraceFcFuser(FuseBase):
         attrs["out_features"] = parameters[weight_name].shape[0]
         linear_name = "linear{}".format(self.linear_index)
         self.linear_index += 1
-        parameters["{}.weight".format(linear_name)] = parameters[
-            weight_name].transpose((1, 0))
-        parameters["{}.bias".format(linear_name)] = np.squeeze(parameters[
-            bias_name])
+        weight_numpy = parameters[weight_name]
+        parameters["{}.weight".format(linear_name)] = weight_numpy.transpose(
+            (1, 0))
+        self.rm_params.add(weight_name)
+        bias_numpy = parameters[bias_name]
+        parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy)
+        self.rm_params.add(bias_name)
         new_layer = PaddleLayer(
             layers_id[0],
             "paddle.nn.Linear",
diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py
index 1139c0bbb95ab97dc232379c0a49e7291d8983a7..a49c07371f117283a2b8d7f0f3f4e2419a1382a3 100644
--- a/x2paddle/optimizer/pattern_matcher.py
+++ b/x2paddle/optimizer/pattern_matcher.py
@@ -325,6 +325,7 @@ class FuseBase(object):
     def __init__(self):
         self.pattern = PaddleGraph()
         self.patterns = list()
+        self.rm_params = set()
 
     def operate(self, graph, match_kind="topo"):
         parameters = graph.parameters
@@ -335,6 +336,8 @@ class FuseBase(object):
         subgraph = get_subgraph("", first_layer_id, graph)
         self.insert_new_layer(subgraph, parameters, match)
         self.delete_match(graph)
+        for param_name in self.rm_params:
+            parameters.pop(param_name)
         graph.build()
 
     def perform_pattern_matcher(self, graph, match_kind="topo"):