diff --git a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py index e47aadb8fe8bed5b6988c0ba18efcaee9183c1b0..5bab8940ed778e60d7f8e79c7f395208c7792bc2 100644 --- a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py @@ -1388,13 +1388,13 @@ class OpSet9(): "y": output_name + "__mul"}, outputs=[output_name]) else: - if mode == 'channel' and len(shape_slope) == 1: + if mode == 'channel': slope_data = _const_weight_or_none(val_slope) - self.weights[val_slope.name] = slope_data + if len(shape_slope) > 1: + self.weights[val_slope.name] = np.reshape(slope_data, shape_slope[0]) num_parameters = val_x.out_shapes[0][1] else: num_parameters = 1 - self.paddle_graph.add_layer( "paddle.nn.PReLU", inputs={"x": val_x.name},