未验证 提交 aab6c75d 编写于 作者: J Jason 提交者: GitHub

Merge pull request #769 from wjj19950828/add_linear

Add aten::linear
...@@ -114,7 +114,7 @@ Aten: ...@@ -114,7 +114,7 @@ Aten:
| 117 | aten::bitwise\_not | 118 | aten::bitwise\_xor | 119 | aten::bitwise\_and | 120 | aten::silu | | 117 | aten::bitwise\_not | 118 | aten::bitwise\_xor | 119 | aten::bitwise\_and | 120 | aten::silu |
| 121 | aten::repeat\_interleave | 122 | aten::maxpool1d | 123 | aten::frobenius\_norm | 124 | aten::format | | 121 | aten::repeat\_interleave | 122 | aten::maxpool1d | 123 | aten::frobenius\_norm | 124 | aten::format |
| 125 | aten::complex | 126 | aten::real | 127 | aten::imag | 128 | aten::fft\_rfftn | | 125 | aten::complex | 126 | aten::real | 127 | aten::imag | 128 | aten::fft\_rfftn |
| 129 | aten::fft\_irfftn | 130 | aten::hardsigmoid | 131 | aten::hardswish | | | | 129 | aten::fft\_irfftn | 130 | aten::hardsigmoid | 131 | aten::hardswish | 132 | aten::linear |
Prim: Prim:
......
...@@ -3238,6 +3238,55 @@ def aten_len(mapper, graph, node): ...@@ -3238,6 +3238,55 @@ def aten_len(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_linear(mapper, graph, node):
    """Translate a TorchScript ``aten::linear`` node into Paddle layers.

    TorchScript Code:
        %x.6 : Float(1, 128, strides=[128, 1]) = aten::linear(%input.305, %weight.629, %bias.317)

    Parameter meaning:
        %x.6 (Tensor): output
        %input.305 (Tensor): input tensor
        %weight.629 (Tensor): weight tensor
        %bias.317 (Tensor): bias tensor

    Two layers are appended to ``graph``: a ``paddle.transpose`` of the
    weight with ``perm=[1, 0]``, followed by ``paddle.nn.functional.linear``
    consuming the input, the transposed weight and (when the node has
    exactly three inputs) the bias.

    Returns:
        tuple: ``(current_inputs, current_outputs)``. ``current_inputs`` is
        the list of names fed to the linear layer, so its weight entry is
        the generated ``<weight>_transpose`` name rather than the original
        weight name. ``current_outputs`` starts as the node's first output
        name and is the list handed to ``mapper._check_input``.
    """
    scope_name = mapper.normalize_scope_name(node)
    result_name = mapper._get_outputs_name(node)[0]
    inputs_name, inputs_node = mapper._get_inputs_name(node)

    # Deliberately a different list object from the one given to the linear
    # layer below: this one is passed to mapper._check_input for every input.
    current_outputs = [result_name]

    # Input tensor.
    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
                        scope_name)
    linear_inputs = {"x": inputs_name[0]}

    # Weight tensor: swap its two axes before feeding it to the linear layer.
    # NOTE(review): presumably needed because Paddle's functional linear
    # expects the weight as (in_features, out_features) -- confirm.
    weight_name = inputs_name[1]
    transposed_weight = weight_name + "_transpose"
    mapper._check_input(graph, inputs_node[1], weight_name, current_outputs,
                        scope_name)
    graph.add_layer(
        "paddle.transpose",
        inputs={"x": weight_name},
        outputs=[transposed_weight],
        scope_name=scope_name,
        perm=[1, 0])
    linear_inputs["weight"] = transposed_weight

    # Bias tensor, wired in only when the node carries exactly three inputs.
    if len(inputs_name) == 3:
        bias_name = inputs_name[2]
        mapper._check_input(graph, inputs_node[2], bias_name, current_outputs,
                            scope_name)
        linear_inputs["bias"] = bias_name

    current_inputs = list(linear_inputs.values())
    graph.add_layer(
        "paddle.nn.functional.linear",
        inputs=linear_inputs,
        outputs=[result_name],
        scope_name=scope_name)
    return current_inputs, current_outputs
def aten_log(mapper, graph, node): def aten_log(mapper, graph, node):
""" 构构造log的PaddleLayer。 """ 构构造log的PaddleLayer。
TorchScript示例: TorchScript示例:
......
...@@ -113,10 +113,13 @@ class TraceFcFuser(FuseBase): ...@@ -113,10 +113,13 @@ class TraceFcFuser(FuseBase):
attrs["out_features"] = parameters[weight_name].shape[0] attrs["out_features"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index) linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1 self.linear_index += 1
parameters["{}.weight".format(linear_name)] = parameters[ weight_numpy = parameters[weight_name]
weight_name].transpose((1, 0)) parameters["{}.weight".format(linear_name)] = weight_numpy.transpose(
parameters["{}.bias".format(linear_name)] = np.squeeze(parameters[ (1, 0))
bias_name]) self.rm_params.add(weight_name)
bias_numpy = parameters[bias_name]
parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy)
self.rm_params.add(bias_name)
new_layer = PaddleLayer( new_layer = PaddleLayer(
layers_id[0], layers_id[0],
"paddle.nn.Linear", "paddle.nn.Linear",
......
...@@ -325,6 +325,7 @@ class FuseBase(object): ...@@ -325,6 +325,7 @@ class FuseBase(object):
def __init__(self): def __init__(self):
self.pattern = PaddleGraph() self.pattern = PaddleGraph()
self.patterns = list() self.patterns = list()
self.rm_params = set()
def operate(self, graph, match_kind="topo"): def operate(self, graph, match_kind="topo"):
parameters = graph.parameters parameters = graph.parameters
...@@ -335,6 +336,8 @@ class FuseBase(object): ...@@ -335,6 +336,8 @@ class FuseBase(object):
subgraph = get_subgraph("", first_layer_id, graph) subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, parameters, match) self.insert_new_layer(subgraph, parameters, match)
self.delete_match(graph) self.delete_match(graph)
for param_name in self.rm_params:
parameters.pop(param_name)
graph.build() graph.build()
def perform_pattern_matcher(self, graph, match_kind="topo"): def perform_pattern_matcher(self, graph, match_kind="topo"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册