diff --git a/python/paddle/fluid/contrib/sparsity/utils.py b/python/paddle/fluid/contrib/sparsity/utils.py index bb030cbac1beaf814987e5cf6a21075ff21d58ee..a72ea4d9b851083ba2565678bf6eb1992bf0f406 100644 --- a/python/paddle/fluid/contrib/sparsity/utils.py +++ b/python/paddle/fluid/contrib/sparsity/utils.py @@ -518,9 +518,13 @@ def create_mask(tensor, func_name=MaskAlgo.MASK_1D, n=2, m=4): t = t.reshape(shape[0], shape[1]) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op + # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op elif len(shape) == 4: - t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) + t = t.transpose([0, 1, 3, 2]).reshape(shape[0] * shape[1] * shape[3], + shape[2]) + mask = func(t, n=n, m=m) + return mask.reshape([shape[0], shape[1], shape[3], + shape[2]]).transpose([0, 1, 3, 2]).astype(dtype) else: raise ValueError("The dimension of input tensor is not supported in create_mask, " \ "Only dimension < 4 is supported but got {}".format(len(shape))) @@ -572,9 +576,10 @@ def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4): t = t.reshape(shape[0], shape[1]) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op + # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op elif len(shape) == 4: - t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) + t = t.transpose([0, 1, 3, 2]).reshape( + [shape[0] * shape[1] * shape[3], shape[2]]) else: raise ValueError("The dimension of input tensor is not supported in create_mask, " \ "Only dimension < 4 is supported but got {}".format(len(shape)))