Commit 094329bd authored by W wjj19950828

resolve conflict

...
@@ -115,7 +115,7 @@ Aten:
 | 121 | aten::repeat\_interleave | 122 | aten::maxpool1d | 123 | aten::frobenius\_norm | 124 | aten::format |
 | 125 | aten::complex | 126 | aten::real | 127 | aten::imag | 128 | aten::fft\_rfftn |
 | 129 | aten::fft\_irfftn | 130 | aten::hardsigmoid | 131 | aten::hardswish | 132 | aten::linear |
-| 133 | aten::rsqrt | 134 | aten::replication\_pad1d | 135 | aten::full | 136 | aten::argmax |
+| 133 | aten::rsqrt | 134 | aten::replication\_pad1d | 135 | aten::full | 136 | aten::group\_norm |
 | 137 | aten::argmax | 138 | aten::copy | | | | |
 Prim:
...
@@ -2732,6 +2732,59 @@ def aten_gt(mapper, graph, node):
     return current_inputs, current_outputs
+
+def aten_group_norm(mapper, graph, node):
+    """
+    TorchScript Code:
+        %input.81 : Tensor = aten::group_norm(%input.2, %25, %60, %59, %26, %30)
+    Parameter meaning:
+        %input.81 (Tensor): Output Tensor
+        %input.2 (Tensor): Input Tensor
+        %25 (int): num_groups
+        %60 (Tensor): weight
+        %59 (Tensor): bias
+        %26 (float): eps
+        %30 (bool): whether cudnn is enabled
+    """
+    scope_name = mapper.normalize_scope_name(node)
+    op_name = name_generator("groupnorm", mapper.nn_name2id)
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [op_name, output_name]
+    layer_inputs = {}
+    layer_attrs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # output list
+    current_outputs = [output_name]
+    # process Input Tensor
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
+                        scope_name)
+    layer_inputs["input"] = inputs_name[0]
+    # input list
+    current_inputs = list(layer_inputs.values())
+    # process num_groups
+    layer_attrs['num_groups'] = mapper.attrs[inputs_name[1]]
+    # process weight
+    weights = mapper.pytorch_params[inputs_name[2]]
+    mapper.paddle_params[op_name + ".weight"] = weights
+    layer_attrs['num_channels'] = weights.shape[0]
+    # process bias
+    if inputs_name[3] in mapper.pytorch_params:
+        bias = mapper.pytorch_params[inputs_name[3]]
+        if bias is not None:
+            mapper.paddle_params[op_name + ".bias"] = bias
+    else:
+        mapper.paddle_params[op_name + ".bias"] = False
+    # process eps
+    layer_attrs["epsilon"] = mapper.attrs[inputs_name[4]]
+    graph.add_layer(
+        "paddle.nn.GroupNorm",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
+    return current_inputs, current_outputs
+
 def aten_gru(mapper, graph, node):
     """ Construct the PaddleLayer for a gated recurrent unit (GRU) network.
     TorchScript example:
...
@@ -2931,45 +2984,38 @@ def aten_hardswish(mapper, graph, node):
 def aten_index(mapper, graph, node):
-    """ 构造选择元素的PaddleLayer。
-    TorchScript示例:
+    """
+    TorchScript Code:
         %1681 : Float = aten::index(%1653, %1680)
-    参数含义:
-        %1681 (Tensor): 输出,选择后的Tensor。
-        %1653 (Tensor): 需要选择的Tensor。
-        %1680 (int): 选择的索引。
+    Parameter meaning:
+        %1681 (Tensor): Output Tensor
+        %1653 (Tensor): Input Tensor
+        %1680 (int): Index
     """
     scope_name = mapper.normalize_scope_name(node)
     output_name = mapper._get_outputs_name(node)[0]
     layer_outputs = [output_name]
     layer_inputs = {}
-    layer_attrs = {}
     inputs_name, inputs_node = mapper._get_inputs_name(node)
-    # 获取当前节点输出的list
+    # output list
     current_outputs = [output_name]
-    # 处理输入0,即%1653
+    # process Input Tensor
     mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
                         scope_name)
     layer_inputs["x"] = inputs_name[0]
-    # 处理输入1,即%1680
+    # process Index
     mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
                         scope_name)
     layer_inputs["index"] = inputs_name[1]
-    # 获取当前节点输入的list
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
         "prim.getitem",
-        inputs={"list": layer_inputs["index"]},
-        outputs=[layer_inputs["index"]],
-        scope_name=scope_name,
-        index=0)
-    graph.add_layer(
-        "paddle.index_select",
-        inputs=layer_inputs,
+        inputs={"list": layer_inputs["x"]},
         outputs=layer_outputs,
         scope_name=scope_name,
-        **layer_attrs)
+        index=layer_inputs["index"])
     return current_inputs, current_outputs
...
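
For context on the new `aten_group_norm` mapper: `aten::group_norm` and the `paddle.nn.GroupNorm` layer it is mapped to share the same semantics — the C channels of an NCHW input are split into `num_groups` groups, each group is normalized over its channels and spatial positions, and a per-channel affine transform (the `weight`/`bias` the mapper copies into `paddle_params`) is applied afterwards. A minimal NumPy sketch of that computation (`group_norm_ref` is an illustrative helper, not part of X2Paddle):

```python
import numpy as np

def group_norm_ref(x, num_groups, weight=None, bias=None, eps=1e-5):
    # Split the C channels of an NCHW tensor into `num_groups` groups,
    # then normalize each group over its channels and spatial positions.
    n, c, h, w = x.shape
    assert c % num_groups == 0, "C must be divisible by num_groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    out = ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    if weight is not None:  # per-channel scale (gamma)
        out = out * weight.reshape(1, c, 1, 1)
    if bias is not None:    # per-channel shift (beta)
        out = out + bias.reshape(1, c, 1, 1)
    return out

x = np.random.rand(2, 6, 4, 4).astype("float32")
y = group_norm_ref(x, num_groups=3, weight=np.ones(6, "float32"))
print(y.shape)  # (2, 6, 4, 4)
```

This also shows why the mapper can read `num_channels` from `weights.shape[0]`: the affine weight has exactly one entry per channel.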
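Similarly, for the `aten_index` change: `aten::index` is the op behind PyTorch's bracket indexing with tensor indices (`x[idx]`), whose second argument is a list of index tensors. The revised mapper lowers it to a single `prim.getitem` — plain subscripting with that index list — instead of the earlier `prim.getitem` plus `paddle.index_select` pair. A short NumPy illustration of the indexing behavior being mapped (illustrative only, not converter code):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

# Advanced indexing: select rows 2 and 0 along the leading dimension,
# which is what x[idx] (and hence aten::index) computes here.
print(x[idx])
# [[ 8  9 10 11]
#  [ 0  1  2  3]]
```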