From bc5d32fe1abf9e7436532583060deec28b85364c Mon Sep 17 00:00:00 2001 From: Channingss Date: Mon, 10 Aug 2020 08:40:32 +0000 Subject: [PATCH] fix bug --- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 557465d..51c60d9 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -672,7 +672,7 @@ class OpSet9(): perm = list(range(len(val_x.out_shapes[0]))) perm = [axis] + perm[:axis] + perm[axis + 1:] attr_trans = {'perm': perm} - name_trans = val_x.layer_name + '_trans' + name_trans = val_x.layer_name + '_transpose' node.fluid_code.add_layer( 'transpose', inputs=val_x, @@ -684,8 +684,12 @@ class OpSet9(): 'index': indices_reshape}, output=node, param_attr=None) + input_transpose = node.layer_name + '_transpose' node.fluid_code.add_layer( - 'transpose', inputs=node, output=node, param_attr=attr_trans) + 'transpose', + inputs=node, + output=input_transpose, + param_attr=attr_trans) val_x_shape = val_x.out_shapes[0] reshaped_shape = [] for i in perm: @@ -694,7 +698,7 @@ class OpSet9(): reshaped_shape.append(i) node.fluid_code.add_layer( 'reshape', - inputs=node, + inputs=input_transpose, output=node, param_attr={'shape': reshaped_shape}) @@ -748,17 +752,21 @@ class OpSet9(): } else: if starts.dtype != 'int32': + starts_cast = starts.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=starts, - output=starts, + output=starts_cast, param_attr={'dtype': string('int32')}) + attr['starts'] = starts_cast if ends.dtype != 'int32': + ends_cast = ens.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=ends, - output=ends, + output=ends_cast, param_attr={'dtype': string('int32')}) + attr['ends'] = ends_cast else: starts = node.get_attr('starts') ends = node.get_attr('ends') -- GitLab