提交 1b7cdc86 编写于 作者: H hiranoaya

matmul tuning space fixed

上级 df57a6cf
......@@ -484,18 +484,39 @@ def _get_space_matmul_cube(op_desc: MatmulCubeDesc):
mmax = (m + 15) // 16
nmax = (n + 15) // 16
kmax = (k + 15) // 16
size_scale = 2
l1_max_size = (1024 * 1024) / size_scale
l0a_max_size = (64 * 1024) / size_scale
l0b_max_size = (64 * 1024) / size_scale
l0c_max_size = ((256 - 8) * 1024) / size_scale
double_buffer = True
mad_fp32 = True
l1_max_size = (1024 * 1024) # L1 MEM 1024KB
l0a_max_size = (64 * 1024) # L0A MEM 64KB
l0b_max_size = (64 * 1024) # L0B MEM 64KB
l0c_max_size = (256 * 1024) # L0C MEM 256KB
ub_max_size = ((256 - 8) * 1024) # UB MEM 248KB, 8KB reserved for compiler
if double_buffer:
l1_max_size = l1_max_size // 2
l0a_max_size = l0a_max_size // 2
l0b_max_size = l0b_max_size // 2
l0c_max_size = l0c_max_size // 2
ub_max_size = ub_max_size // 2
if mad_fp32:
l0c_max_size = l0c_max_size // 2
if op_desc.out_dtype == 'float32':
l0c_max_size = l0c_max_size / 2
ub_max_size = ub_max_size // 2
bypass_options = [0, 1, 2]
for bypass in bypass_options:
if (bypass == 2) and ((op_desc.adj_x == False and op_desc.left_format[0].lower() == 'n') or
(op_desc.adj_x == True and op_desc.left_format[0].lower() == 'z')):
continue
if (bypass == 1) and ((op_desc.adj_y == False and op_desc.right_format[0].lower() == 'z') or
(op_desc.adj_y == True and op_desc.right_format[0].lower() == 'n')):
continue
for k_l1 in range(1, kmax + 1):
if kmax % k_l1 != 0:
continue
......@@ -528,15 +549,18 @@ def _get_space_matmul_cube(op_desc: MatmulCubeDesc):
if m_l0 * 16 * n_l0 * 16 > l0c_max_size:
continue
if m_l0 * 16 * n_l0 * 16 > ub_max_size:
continue
if bypass == 2:
l1_size = n_l1 * 16 * k_l1 * 16
elif bypass == 1:
l1_size = m_l1 * 16 * k_l1 * 16
else:
l1_size = (m_l1 * 16 + n_l1 * 16) * k_l1 * 16
if l1_size > l1_max_size:
continue
if nmax == 1:
n_l1 = 0
n_l0 = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册