提交 6cdf2161 编写于 作者: S SunAhong1993

modify onnx

上级 fc6042a3
...@@ -1293,25 +1293,23 @@ class OpSet9(): ...@@ -1293,25 +1293,23 @@ class OpSet9():
if shape_slope == [1]: if shape_slope == [1]:
mode = 'all' mode = 'all'
elif len(shape_slope) > 2: elif len(shape_slope) > 2:
mode = 'element' raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1: if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel] # paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope) slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope) slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.layer_name] = slope_data self.weights[val_slope.layer_name] = slope_data
num_parameters = val_x.out_shapes[0][1]
else:
num_parameters = 1
layer_attrs = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode),
"channel": val_x.out_shapes[0][1] if mode == "channel" else None,
"input_shape": val_x.out_shapes[0] if mode == "element" else None,
}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.nn.PReLU", "paddle.nn.PReLU",
inputs={"x": self.get_node_name(val_x)}, inputs={"x": self.get_node_name(val_x)},
outputs=layer_outputs, outputs=layer_outputs,
**layer_attrs) num_parameters=num_parameters,
weight_attr=string(val_slope.layer_name))
@print_mapping_info @print_mapping_info
def Squeeze(self, node): def Squeeze(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册