diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index d562228ffd51bf7895582dc5af4c48dbd59a8aef..08107b55d40ace47e057580d2fcc93d5867bb439 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -406,7 +406,7 @@ class TFConverter(object): input_channels = weight_tensor_value.shape[2] # HWIO -> OIHW weight_tensor_value = weight_tensor_value.transpose(3, 2, 0, 1) - if input_shape[2] > 16 and input_shape[3] > 16: + if input_shape[1] > 16 and input_shape[2] > 16: G = np.array([ [1.0, 0.0, 0.0], [-2.0 / 9, -2.0 / 9, -2.0 / 9],