diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 45fd2e8c570e43d0b77b559181bf363d09c80011..4d22d0246cc0da6462245bf01187d5de355eb235 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -718,7 +718,7 @@ class TFOpMapper(OpMapper): if input.tf_data_format == "NHWC": if len(input.out_shapes[0]) == 4: expand_times = [expand_times[i] for i in [0, 3, 1, 2]] - elif len(input.out_shape[0]) == 3: + elif len(input.out_shapes[0]) == 3: expand_times = [expand_times[i] for i in [2, 0, 1]] for i in range(len(expand_times)): if expand_times[i] < 0: