未验证 提交 00245cfd 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-ASP] Revise 4d tensor sparsity mask pattern for conv2d sparsity (#36054)

Sparse tensor core for convolution requires the input channel dimension is 2:4 structed sparse.
So we have to mask the input channel dimension for using sparse tensor core
上级 c38b0488
...@@ -518,9 +518,13 @@ def create_mask(tensor, func_name=MaskAlgo.MASK_1D, n=2, m=4): ...@@ -518,9 +518,13 @@ def create_mask(tensor, func_name=MaskAlgo.MASK_1D, n=2, m=4):
t = t.reshape(shape[0], shape[1]) t = t.reshape(shape[0], shape[1])
elif len(shape) == 3: elif len(shape) == 3:
t = t.reshape(shape[0] * shape[1], shape[2]) t = t.reshape(shape[0] * shape[1], shape[2])
# 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op
elif len(shape) == 4: elif len(shape) == 4:
t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) t = t.transpose([0, 1, 3, 2]).reshape(shape[0] * shape[1] * shape[3],
shape[2])
mask = func(t, n=n, m=m)
return mask.reshape([shape[0], shape[1], shape[3],
shape[2]]).transpose([0, 1, 3, 2]).astype(dtype)
else: else:
raise ValueError("The dimension of input tensor is not supported in create_mask, " \ raise ValueError("The dimension of input tensor is not supported in create_mask, " \
"Only dimension < 4 is supported but got {}".format(len(shape))) "Only dimension < 4 is supported but got {}".format(len(shape)))
...@@ -572,9 +576,10 @@ def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4): ...@@ -572,9 +576,10 @@ def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4):
t = t.reshape(shape[0], shape[1]) t = t.reshape(shape[0], shape[1])
elif len(shape) == 3: elif len(shape) == 3:
t = t.reshape(shape[0] * shape[1], shape[2]) t = t.reshape(shape[0] * shape[1], shape[2])
# 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op # 4d-tensor conv (h, w, in, out) -> (h*w*out, in) in GemmConvKernel Op
elif len(shape) == 4: elif len(shape) == 4:
t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) t = t.transpose([0, 1, 3, 2]).reshape(
[shape[0] * shape[1] * shape[3], shape[2]])
else: else:
raise ValueError("The dimension of input tensor is not supported in create_mask, " \ raise ValueError("The dimension of input tensor is not supported in create_mask, " \
"Only dimension < 4 is supported but got {}".format(len(shape))) "Only dimension < 4 is supported but got {}".format(len(shape)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册