提交 e98f6d90 编写于 作者: S SunAhong1993

fix the caffe leaklyrelu and add aten_prelu

上级 fa0df3cf
......@@ -571,7 +571,7 @@ class CaffeOpMapper(OpMapper):
if params.HasField('negative_slope') and params.negative_slope != 0:
negative_slope = float(params.negative_slope)
layer_attrs = {'alpha': negative_slope}
layer_attrs = {'negative_slope': negative_slope}
self.paddle_graph.add_layer(
"paddle.nn.LeakyReLU",
inputs={"input": input.name},
......
......@@ -3353,6 +3353,42 @@ def aten_pow(mapper, graph, node):
return current_inputs, current_outputs
def aten_prelu(mapper, graph, node):
""" 构造prelu激活的PaddleLayer。
TorchScript示例:
%result.3 : aten::prelu(%input.150, %999)
参数含义:
%result.3 (Tensor): 输出,prelu后的结果。
%input.150 (Tensor): 需要prelu的Tensor。
%999 (Tnsor): 权重。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("relu", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%result.150
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%999
weight = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[op_name + "._weight"] = weight
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"paddle.nn.PReLU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
num_parameters=weight.shape[0])
return current_inputs, current_outputs
def aten_relu(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。
......
......@@ -27,6 +27,7 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Linear": "linear",
"paddle.nn.Conv2DTranspose": "conv",
"paddle.nn.LSTM": "lstm",
"paddle.nn.PReLU": "prelu",
"paddle.nn.ReLU": "relu",
"paddle.nn.ReLU6": "relu",
"paddle.nn.Softmax": "softmax",
......@@ -41,7 +42,7 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.GELU": "gelu",
"paddle.nn.Hardtanh": "tanh",
"paddle.nn.LeakyReLU": "leakly_relu"}
NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:7]
NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:8]
def rename_layers(layers, param_tree=None, is_rename_module=False):
""" 对子模块的输入输出等进行重命名。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册