diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index 5e1a54b2dab5f9a4537e7c53fecfa7d65a74e4b5..81079df9da0a6faf12d7a476749e52835a4f27aa 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -220,7 +220,10 @@ class TFOpMapperNHWC(OpMapper): block_size = node.get_attr("block_size") data_format = node.get_attr("data_format").decode() - n, h, w, c = input.out_shapes[0] + if data_format == "NHWC": + n, h, w, c = input.out_shapes[0] + else: + n, c, h, w = input.out_shapes[0] input_name = input.name if data_format == "NHWC":