提交 1372a6d7 编写于 作者: C Channingss

[ONNX] fix bug of prelu

上级 423ceb0d
...@@ -190,6 +190,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -190,6 +190,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
mapper = ONNXOpMapper(model) mapper = ONNXOpMapper(model)
print("Model optimizing ...") print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper) optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.") print("Model optimized.")
print("Paddle model and code generating ...") print("Paddle model and code generating ...")
......
...@@ -1256,10 +1256,19 @@ class OpSet9(): ...@@ -1256,10 +1256,19 @@ class OpSet9():
mode = 'channel' mode = 'channel'
shape_slope = val_slope.out_shapes[0] shape_slope = val_slope.out_shapes[0]
if len(shape_slope) == 1:
if shape_slope == [1]:
mode = 'all' mode = 'all'
elif len(shape_slope) > 2: elif len(shape_slope) > 2:
mode = 'element' mode = 'element'
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.layer_name] = slope_data
self.omit_nodes.append(val_slope.layer_name)
attr = { attr = {
"param_attr": string(val_slope.layer_name), "param_attr": string(val_slope.layer_name),
'mode': string(mode) 'mode': string(mode)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册