提交 14889935 编写于 作者: S SunAhong1993

add prelu

上级 ba5a3314
......@@ -1359,6 +1359,34 @@ class OpSet9():
elif len(shape_slope) > 2:
raise Exception("The 'element' mode is not supported yet!")
if mode == "element":
self.paddle_graph.add_layer(
"paddle.zeros",
inputs={},
outputs=[output_name + "__zeros"],
shape=shape_slope,
dtype=string(node.dtype))
self.paddle_graph.add_layer(
"paddle.maximum",
inputs={"x": val_x.name,
"y": output_name + "__zeros"},
outputs=[output_name + "__max"])
self.paddle_graph.add_layer(
"paddle.minimum",
inputs={"x": val_x.name,
"y": output_name + "__zeros"},
outputs=[output_name + "__max"])
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={"x": val_slope.name,
"y": output_name + "__min"},
outputs=[output_name + "__mul"])
self.paddle_graph.add_layer(
"paddle.add",
inputs={"x": output_name + "__max",
"y": output_name + "__mul"},
outputs=[output_name])
else:
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
......
......@@ -1336,6 +1336,14 @@ class OpSet9():
elif len(shape_slope) > 2:
raise Exception("The 'element' mode is not supported yet!")
if mode == "element":
self.paddle_graph.add_layer(
"paddle.static.nn.prelu",
inputs={"x": val_x.name,
"param_attr": val_slope.name},
outputs=[node.name],
mode="element")
else:
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册