diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 557465df9dc60869f7fd4673a56faad4a01df317..51c60d9b5e9b527369256c94fe65fed93176bd66 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -672,7 +672,7 @@ class OpSet9(): perm = list(range(len(val_x.out_shapes[0]))) perm = [axis] + perm[:axis] + perm[axis + 1:] attr_trans = {'perm': perm} - name_trans = val_x.layer_name + '_trans' + name_trans = val_x.layer_name + '_transpose' node.fluid_code.add_layer( 'transpose', inputs=val_x, @@ -684,8 +684,12 @@ class OpSet9(): 'index': indices_reshape}, output=node, param_attr=None) + input_transpose = node.layer_name + '_transpose' node.fluid_code.add_layer( - 'transpose', inputs=node, output=node, param_attr=attr_trans) + 'transpose', + inputs=node, + output=input_transpose, + param_attr=attr_trans) val_x_shape = val_x.out_shapes[0] reshaped_shape = [] for i in perm: @@ -694,7 +698,7 @@ class OpSet9(): reshaped_shape.append(i) node.fluid_code.add_layer( 'reshape', - inputs=node, + inputs=input_transpose, output=node, param_attr={'shape': reshaped_shape}) @@ -748,17 +752,21 @@ class OpSet9(): } else: if starts.dtype != 'int32': + starts_cast = starts.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=starts, - output=starts, + output=starts_cast, param_attr={'dtype': string('int32')}) + attr['starts'] = starts_cast if ends.dtype != 'int32': + ends_cast = ens.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=ends, - output=ends, + output=ends_cast, param_attr={'dtype': string('int32')}) + attr['ends'] = ends_cast else: starts = node.get_attr('starts') ends = node.get_attr('ends')