提交 14889935 编写于 作者: S SunAhong1993

add prelu

上级 ba5a3314
...@@ -1359,21 +1359,49 @@ class OpSet9(): ...@@ -1359,21 +1359,49 @@ class OpSet9():
elif len(shape_slope) > 2: elif len(shape_slope) > 2:
raise Exception("The 'element' mode is not supported yet!") raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1: if mode == "element":
# paddle params shape need be [1, channel] self.paddle_graph.add_layer(
slope_data = _const_weight_or_none(val_slope) "paddle.zeros",
slope_data = np.reshape(slope_data, [1] + shape_slope) inputs={},
self.weights[val_slope.name] = slope_data outputs=[output_name + "__zeros"],
num_parameters = val_x.out_shapes[0][1] shape=shape_slope,
dtype=string(node.dtype))
self.paddle_graph.add_layer(
"paddle.maximum",
inputs={"x": val_x.name,
"y": output_name + "__zeros"},
outputs=[output_name + "__max"])
self.paddle_graph.add_layer(
"paddle.minimum",
inputs={"x": val_x.name,
"y": output_name + "__zeros"},
outputs=[output_name + "__max"])
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={"x": val_slope.name,
"y": output_name + "__min"},
outputs=[output_name + "__mul"])
self.paddle_graph.add_layer(
"paddle.add",
inputs={"x": output_name + "__max",
"y": output_name + "__mul"},
outputs=[output_name])
else: else:
num_parameters = 1 if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.name] = slope_data
num_parameters = val_x.out_shapes[0][1]
else:
num_parameters = 1
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.nn.PReLU", "paddle.nn.PReLU",
inputs={"x": val_x.name}, inputs={"x": val_x.name},
outputs=layer_outputs, outputs=layer_outputs,
num_parameters=num_parameters, num_parameters=num_parameters,
weight_attr=string(val_slope.name)) weight_attr=string(val_slope.name))
@print_mapping_info @print_mapping_info
def Squeeze(self, node): def Squeeze(self, node):
......
...@@ -1336,17 +1336,25 @@ class OpSet9(): ...@@ -1336,17 +1336,25 @@ class OpSet9():
elif len(shape_slope) > 2: elif len(shape_slope) > 2:
raise Exception("The 'element' mode is not supported yet!") raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1: if mode == "element":
# paddle params shape need be [1, channel] self.paddle_graph.add_layer(
slope_data = _const_weight_or_none(val_slope) "paddle.static.nn.prelu",
slope_data = np.reshape(slope_data, [1] + shape_slope) inputs={"x": val_x.name,
self.params[val_slope.name] = slope_data "param_attr": val_slope.name},
outputs=[node.name],
self.paddle_graph.add_layer( mode="element")
"paddle.nn.functional.prelu", else:
inputs={"x": val_x.name, if mode == 'channel' and len(shape_slope) == 1:
"weight": val_slope.name}, # paddle params shape need be [1, channel]
outputs=[node.name]) slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.params[val_slope.name] = slope_data
self.paddle_graph.add_layer(
"paddle.nn.functional.prelu",
inputs={"x": val_x.name,
"weight": val_slope.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Squeeze(self, node): def Squeeze(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册