From 00245cfd2e5fe175a80d13a67b5c75e27930ce59 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Mon, 11 Oct 2021 18:40:07 +0800 Subject: [PATCH] [Paddle-ASP] Revise 4d tensor sparsity mask pattern for conv2d sparsity (#36054) Sparse tensor core for convolution requires the input channel dimension is 2:4 structed sparse. So we have to mask the input channel dimension for using sparse tensor core --- python/paddle/fluid/contrib/sparsity/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/contrib/sparsity/utils.py b/python/paddle/fluid/contrib/sparsity/utils.py index bb030cbac1b..a72ea4d9b85 100644 --- a/python/paddle/fluid/contrib/sparsity/utils.py +++ b/python/paddle/fluid/contrib/sparsity/utils.py @@ -518,9 +518,13 @@ def create_mask(tensor, func_name=MaskAlgo.MASK_1D, n=2, m=4): t = t.reshape(shape[0], shape[1]) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op + # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op elif len(shape) == 4: - t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) + t = t.transpose([0, 1, 3, 2]).reshape(shape[0] * shape[1] * shape[3], + shape[2]) + mask = func(t, n=n, m=m) + return mask.reshape([shape[0], shape[1], shape[3], + shape[2]]).transpose([0, 1, 3, 2]).astype(dtype) else: raise ValueError("The dimension of input tensor is not supported in create_mask, " \ "Only dimension < 4 is supported but got {}".format(len(shape))) @@ -572,9 +576,10 @@ def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4): t = t.reshape(shape[0], shape[1]) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op + # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op elif len(shape) == 4: - t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) + t = t.transpose([0, 1, 3, 2]).reshape( + [shape[0] * shape[1] * shape[3], shape[2]]) else: raise ValueError("The dimension of input tensor is not supported in create_mask, " \ "Only dimension < 4 is supported but got {}".format(len(shape))) -- GitLab