提交 4d151cf8 编写于 作者: S SunAhong1993

add aten and comment

上级 e1dda433
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -1044,6 +1044,9 @@ def aten_embedding(mapper, graph, node): ...@@ -1044,6 +1044,9 @@ def aten_embedding(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入2,即%45 # 处理输入2,即%45
if mapper.attrs[inputs_name[2]] == -1:
layer_attrs["padding_idx"] = None
else:
layer_attrs["padding_idx"] = mapper.attrs[inputs_name[2]] layer_attrs["padding_idx"] = mapper.attrs[inputs_name[2]]
# 处理输入4,即%46 # 处理输入4,即%46
# layer_attrs["sparse"] = mapper.attrs[inputs_name[4]] # layer_attrs["sparse"] = mapper.attrs[inputs_name[4]]
...@@ -2933,6 +2936,44 @@ def aten_softplus(mapper, graph, node): ...@@ -2933,6 +2936,44 @@ def aten_softplus(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_squeeze(mapper, graph, node):
""" 构造删除位数为1的维度的PaddleLayer。
TorchScript示例:
%12 : Tensor = aten::squeeze(%start_logits.1, %4)
参数含义:
%12 (Tensor): 输出,删除维度后的Tensor。
%start_logits.1 (Tensor): 需要删除维度的Tensor。
%4 (int): 维度。
"""
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%start_logits.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%4
if inputs_name[1] in mapper.attrs:
layer_attrs["axis"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs)
layer_inputs["axis"] = inputs_name[1]
current_inputs.append(inputs_name[1])
graph.add_layer(
"paddle.tensor.squeeze",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
return current_inputs, current_outputs
def aten_stack(mapper, graph, node): def aten_stack(mapper, graph, node):
""" 构造堆叠Tensor的PaddleLayer。 """ 构造堆叠Tensor的PaddleLayer。
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册