未验证 提交 0ac82113 编写于 作者: J Jason 提交者: GitHub

Merge pull request #418 from Channingss/fix_bug_prelu

[ONNX] prelu support slope len(shape) ==1, but shape !=[1]
......@@ -177,6 +177,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
mapper = ONNXOpMapper(model)
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.")
print("Paddle model and code generating ...")
......
......@@ -550,8 +550,6 @@ class OpSet9():
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {
......@@ -1256,10 +1254,19 @@ class OpSet9():
mode = 'channel'
shape_slope = val_slope.out_shapes[0]
if len(shape_slope) == 1:
if shape_slope == [1]:
mode = 'all'
elif len(shape_slope) > 2:
mode = 'element'
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.layer_name] = slope_data
self.omit_nodes.append(val_slope.layer_name)
attr = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册