From 1372a6d758c24aeef4dfe77355b079023f7715f8 Mon Sep 17 00:00:00 2001 From: Channingss Date: Mon, 12 Oct 2020 09:42:27 +0000 Subject: [PATCH] [ONNX] fix bug of prelu --- x2paddle/convert.py | 1 + x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index d5c78e7..86f6210 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -190,6 +190,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False): mapper = ONNXOpMapper(model) print("Model optimizing ...") optimizer = ONNXOptimizer(mapper) + optimizer.delete_redundance_code() print("Model optimized.") print("Paddle model and code generating ...") diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 3808755..b36560c 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1256,10 +1256,19 @@ class OpSet9(): mode = 'channel' shape_slope = val_slope.out_shapes[0] - if len(shape_slope) == 1: + + if shape_slope == [1]: mode = 'all' elif len(shape_slope) > 2: mode = 'element' + + if mode == 'channel' and len(shape_slope) == 1: + # paddle params shape need be [1, channel] + slope_data = _const_weight_or_none(val_slope) + slope_data = np.reshape(slope_data, [1] + shape_slope) + self.weights[val_slope.layer_name] = slope_data + + self.omit_nodes.append(val_slope.layer_name) attr = { "param_attr": string(val_slope.layer_name), 'mode': string(mode) -- GitLab