# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
import logging
from enum import Enum
from gen_random import random_gaussian
import numpy as np
import akg.backend as cce
from akg.utils import kernel_exec as utils
from akg.ops.nn import matmul
from base import get_rtol_atol
from tensorio import compare_tensor


logging.basicConfig(level=logging.DEBUG)


class MatmulType(Enum):
    gemm = 1
    gevm = 2
    gemv = 3


def get_name(caseIndex=1, name="leftMatrix", M=0, K=0, N=0, adj_x=False, adj_y=False):
    res = "{}_{}_{}_{}_{}_{}_{}.bin".format(caseIndex, name, M, K, N, adj_x, adj_y)
    return res
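# Example: get_name(1, "leftMatrix", 32, 64, 48) returns
# "1_leftMatrix_32_64_48_False_False.bin".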


def get_shape(name="leftMatrix", M=0, K=0, N=0, batch_tuple=(1,), adj_x=False, adj_y=False):
    res_shape = ()
    if name == "leftMatrix":
        if adj_x:
            res_shape = batch_tuple + (K // cce.BLOCK_REDUCE, M // cce.BLOCK_IN, cce.BLOCK_REDUCE, cce.BLOCK_IN)
        else:
            res_shape = batch_tuple + (M // cce.BLOCK_IN, K // cce.BLOCK_REDUCE, cce.BLOCK_IN, cce.BLOCK_REDUCE)

    if name == "rightMatrix":
        if adj_y:
            res_shape = batch_tuple + (N // cce.BLOCK_OUT, K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE, cce.BLOCK_OUT)
        else:
            res_shape = batch_tuple + (K // cce.BLOCK_REDUCE, N // cce.BLOCK_OUT, cce.BLOCK_OUT, cce.BLOCK_REDUCE)

    if name == "result":
        res_shape = batch_tuple + (N // cce.BLOCK_OUT, M // cce.BLOCK_IN, cce.BLOCK_IN, cce.BLOCK_OUT)
    return res_shape
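# Example (assuming cce.BLOCK_IN == cce.BLOCK_REDUCE == 16):
# get_shape("leftMatrix", M=32, K=64, batch_tuple=(1,)) returns
# (1, 2, 4, 16, 16), i.e. the zZ fractal layout of a 32x64 matrix.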


def get_shapes(batch_tuple, M, K, N, trans_data=False, trans_weight=False):
    shape_x = batch_tuple + (M, K)
    if trans_data:
        shape_x = batch_tuple + (K, M)
    shape_y = batch_tuple + (K, N)
    if trans_weight:
        shape_y = batch_tuple + (N, K)
    return shape_x, shape_y
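# Example: get_shapes((2,), M=32, K=64, N=48) returns
# ((2, 32, 64), (2, 64, 48)); with trans_data=True the left shape
# becomes (2, 64, 32).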


def getMatmulType(m, n):
    matmul_type = MatmulType.gemm
    if m // cce.BLOCK_IN == 0:
        matmul_type = MatmulType.gevm
    elif n == 1:
        matmul_type = MatmulType.gemv
    return matmul_type
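# Illustrative classification (assuming cce.BLOCK_IN == 16):
# getMatmulType(32, 48) -> gemm, getMatmulType(1, 48) -> gevm
# (m smaller than one fractal block), getMatmulType(32, 1) -> gemv.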


def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_weight=False, output_format=None):
    """
    implementation for the batch matmul
    :param matrix_a: (batch1, batch2, ..., M, K)
    :param matrix_b: (batch1, batch2, ..., K, N)
    :return:
    (batch1, batch2, ..., M, N)
    """
    batch_len = len(batch_tuple)
    if trans_data:
        matrix_a = matrix_a.transpose(tuple(range(batch_len)) + (batch_len + 1, batch_len))
    if trans_weight:
        matrix_b = matrix_b.transpose(tuple(range(batch_len)) + (batch_len + 1, batch_len))

    mul = int(np.prod(batch_tuple))
    reshape_x = matrix_a.reshape(mul, M, K)
    reshape_y = matrix_b.reshape(mul, K, N)
    flatten_shape = (mul, M, N)
    out = np.zeros(flatten_shape, dtype=np.float16)
    for b in range(mul):
        out[b, :] = np.dot(reshape_x[b, :], reshape_y[b, :])
        #out[b,:] = np.matmul(reshape_x[b,:], reshape_y[b,:])
    matmul_type = getMatmulType(M, N)
    out_shape = ()
    if matmul_type == MatmulType.gemm:
        out_shape = batch_tuple + (M // cce.BLOCK_IN, cce.BLOCK_IN, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
    elif matmul_type == MatmulType.gevm:
        out_shape = batch_tuple + (1, M % cce.BLOCK_IN, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
    elif matmul_type == MatmulType.gemv:
        out_shape = batch_tuple + (M // cce.BLOCK_IN, cce.BLOCK_IN, 1, N % cce.BLOCK_OUT)
    logging.debug(out_shape)
    # default zN output: reorder axes to (No, Mo, Mi, Ni)
    trans = tuple(range(batch_len)) + (batch_len + 2, batch_len, batch_len + 1, batch_len + 3)
    if output_format == "zZ":
        trans = tuple(range(batch_len)) + (batch_len, batch_len + 2, batch_len + 1, batch_len + 3)
    if matmul_type == MatmulType.gemv:
        # use the transpose of out
        trans = tuple(range(batch_len)) + (batch_len, batch_len + 2, batch_len + 3, batch_len + 1)
    res = out.reshape(out_shape).transpose(trans).copy()
    return res
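# Sanity sketch (assuming 16x16 fractal blocks): with batch_tuple=(1,),
# M=K=N=16 and the default zN output, np_matmul returns shape
# (1, 1, 1, 16, 16), and res[0, 0, 0] equals np.dot(matrix_a[0],
# matrix_b[0]) up to float16 rounding.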


def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
            dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"):
    shape_x, shape_y = get_shapes(batch_tuple, M, K, N, trans_data, trans_weight)
    matrix_a = random_gaussian(shape_x, miu=0.1, sigma=0.01).astype(dtype)
    matrix_b = random_gaussian(shape_y, miu=0.1, sigma=0.01).astype(dtype)
    # matrix_a = np.ones(shape_x, dtype=np.float16)
    # matrix_b = np.ones(shape_y, dtype=np.float16)

    # cast to float32 to speed up reference data generation
    matrix_a_for_np = matrix_a.astype(np.float32)
    matrix_b_for_np = matrix_b.astype(np.float32)

    matmul_type = getMatmulType(M, N)
    out = np_matmul(matrix_a_for_np, matrix_b_for_np, batch_tuple, M, K, N, trans_data, trans_weight, output_format).astype(out_dtype)
    if dtype == "float16":
        out.astype(np.float16)

    bias_shape = (N,)
    bias_data = np.full(bias_shape, np.nan, bias_dtype)
    if bias != 0:
        bias_data = random_gaussian(bias_shape, miu=0.5, sigma=0.01).astype(bias_dtype)
        bias_reshape = (N // cce.BLOCK_OUT, 1, 1, cce.BLOCK_OUT)
        if output_format == "zZ":
            bias_reshape = (1, N // cce.BLOCK_OUT, 1, cce.BLOCK_OUT)
        bias_data_reshaped = bias_data.reshape(bias_reshape)
        if bias_dtype != out_dtype:
            out = out.astype(np.float32) + bias_data_reshaped.astype(np.float32)
            out = out.astype(out_dtype)
        else:
            out = out + bias_data_reshaped

    shape_x = ()
    shape_y = ()
    if matmul_type == MatmulType.gemm:
        shape_x = (M // cce.BLOCK_IN, cce.BLOCK_IN, K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE)
        if trans_data:
            shape_x = (K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE, M // cce.BLOCK_IN, cce.BLOCK_IN)
        shape_y = (K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
        if trans_weight:
            shape_y = (N // cce.BLOCK_OUT, cce.BLOCK_OUT, K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE)
    elif matmul_type == MatmulType.gevm:
        shape_x = (1, M % cce.BLOCK_IN, K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE)
        shape_y = (K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE, N // cce.BLOCK_OUT, cce.BLOCK_OUT)
    elif matmul_type == MatmulType.gemv:
        # use transpose(b) * transpose(a)
        shape_x = (M // cce.BLOCK_IN, cce.BLOCK_IN, K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE)
        shape_y = (K // cce.BLOCK_REDUCE, cce.BLOCK_REDUCE, 1, N % cce.BLOCK_OUT)

    batch_len = len(batch_tuple)
    # left_format zZ
    if left_format == "zZ":
        trans_x = tuple(range(batch_len)) + (batch_len + 0, batch_len + 2, batch_len + 1, batch_len + 3)
    elif left_format == "zN":
        trans_x = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 1, batch_len + 3)
    # right_format nZ
    if right_format == "nZ":
        trans_y = tuple(range(batch_len)) + (batch_len + 0, batch_len + 2, batch_len + 3, batch_len + 1)
    elif right_format == "zZ":
        trans_y = tuple(range(batch_len)) + (batch_len + 0, batch_len + 2, batch_len + 1, batch_len + 3)
    elif right_format == "zN":
        trans_y = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 1, batch_len + 3)
    fractal_a = matrix_a.reshape(batch_tuple + shape_x).transpose(trans_x).copy()
    fractal_b = matrix_b.reshape(batch_tuple + shape_y).transpose(trans_y).copy()
    if matmul_type == MatmulType.gemv:
        trans_y = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 3, batch_len + 1)
        trans_x = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 1, batch_len + 3)
        fractal_a = matrix_b.reshape(batch_tuple + shape_y).transpose(trans_y).copy()
        fractal_b = matrix_a.reshape(batch_tuple + shape_x).transpose(trans_x).copy()
    return fractal_a, fractal_b, out, bias_data
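# Note: in the gemv case the operands are swapped above, so the kernel
# effectively computes transpose(b) * transpose(a); the shapes match the
# gemv branch of get_converted_shapes below.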


def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format=None, right_format=None, output_format=None, debug_logging=False):
    logging.debug("gen data start!")
    a = datetime.now()
    m_x, m_y, bench_mark, bias_data = genData(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format)
    b = datetime.now()
    logging.debug("gen data cost: %ss", (b - a).seconds)
    logging.debug("gen data end!")

    if debug_logging:
        logging.debug("m_x shape:{}".format(m_x.shape))
        logging.debug("m_y shape:{}".format(m_y.shape))
        logging.debug(type(m_x))
        logging.debug("bench_mark shape: {}".format(bench_mark.shape))

    return m_x, m_y, bench_mark, bias_data


def extract_dim(shape_x, shape_y, adj_x, adj_y):
    rank = len(shape_x)
    m = shape_x[-1] if adj_x else shape_x[-2]
    k = shape_x[-2] if adj_x else shape_x[-1]
    n = shape_y[-2] if adj_y else shape_y[-1]
    batch_tuple = shape_x[:-2] if rank > 2 else (1,)
    return batch_tuple, m, k, n
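# Example: extract_dim((2, 32, 64), (2, 64, 48), False, False) returns
# ((2,), 32, 64, 48); with adj_x=True the left shape is read as (K, M),
# giving m=64 and k=32.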


def reduce_data(reduce_type):
    res = cce.BLOCK_IN
    if reduce_type == "in":
        res = cce.BLOCK_IN
    elif reduce_type == "out":
        res = cce.BLOCK_OUT
    elif reduce_type == "reduce":
        res = cce.BLOCK_REDUCE
    return res
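# Example: reduce_data("out") returns cce.BLOCK_OUT; any unrecognized
# reduce_type falls back to cce.BLOCK_IN.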


def get_fractal_shape(dim1, dim2, reduce1="in", reduce2="reduce", matrix_format="zZ"):
    result = ()
    dim1_reduce = reduce_data(reduce1)
    dim2_reduce = reduce_data(reduce2)
    if matrix_format == "zZ":
        result = (dim1 // dim1_reduce, dim2 // dim2_reduce, dim1_reduce, dim2_reduce)
    elif matrix_format == "nZ":
        result = (dim1 // dim1_reduce, dim2 // dim2_reduce, dim2_reduce, dim1_reduce)
    elif matrix_format == "nN":
        result = (dim2 // dim2_reduce, dim1 // dim1_reduce, dim2_reduce, dim1_reduce)
    elif matrix_format == "zN":
        result = (dim2 // dim2_reduce, dim1 // dim1_reduce, dim1_reduce, dim2_reduce)

    return result
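# Illustrative shapes (assuming 16x16 blocks): get_fractal_shape(32, 64,
# "in", "reduce", "zZ") == (2, 4, 16, 16), while "zN" moves the outer
# axes around to give (4, 2, 16, 16).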


def get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format="zZ", right_format="nZ", out_format="zN"):
    matmul_type = getMatmulType(m, n)
    if matmul_type == MatmulType.gemm:
        # left_format zZ process
        if left_format == "zZ":
            shape_xx = batch_tuple + get_fractal_shape(m, k, "in", "reduce", "zZ")
            if adj_x:
                shape_xx = batch_tuple + get_fractal_shape(m, k, "in", "reduce", "nN")
        # left_format zN process
        elif left_format == "zN":
            shape_xx = batch_tuple + get_fractal_shape(m, k, "in", "reduce", "zN")
            if adj_x:
                shape_xx = batch_tuple + get_fractal_shape(m, k, "in", "reduce", "nZ")
        else:
            raise RuntimeError("Error: unsupported left matrix format: %s" % left_format)

        # right_format nZ
        if right_format == "nZ":
            shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "nZ")
            if adj_y:
                shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "zN")
        # right_format zZ
        elif right_format == "zZ":
            shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "zZ")
            if adj_y:
                shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "nN")
        elif right_format == "zN":
            shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "zN")
            if adj_y:
                shape_yy = batch_tuple + get_fractal_shape(k, n, "reduce", "out", "nZ")
        else:
            raise RuntimeError("Error: unsupported right matrix format: %s" % right_format)

        # output_format zN
        # output_shape = batch_tuple + (n//cce.BLOCK_OUT, m//cce.BLOCK_IN, cce.BLOCK_IN, cce.BLOCK_OUT)
        if out_format == "zN":
            output_shape = batch_tuple + get_fractal_shape(m, n, "in", "out", "zN")
        elif out_format == "zZ":
            output_shape = batch_tuple + get_fractal_shape(m, n, "in", "out", "zZ")
        else:
            raise RuntimeError("Error: unsupported output matrix format: %s" % out_format)

    elif matmul_type == MatmulType.gevm:
        shape_xx = batch_tuple + (1, k // cce.BLOCK_REDUCE, m % cce.BLOCK_IN, cce.BLOCK_REDUCE)
        shape_yy = batch_tuple + (k // cce.BLOCK_REDUCE, n // cce.BLOCK_OUT, cce.BLOCK_OUT, cce.BLOCK_REDUCE)
        output_shape = batch_tuple + (n // cce.BLOCK_OUT, 1, m % cce.BLOCK_IN, cce.BLOCK_OUT)
    elif matmul_type == MatmulType.gemv:
        # transpose of b * transpose of a
        shape_xx = batch_tuple + (1, k // cce.BLOCK_REDUCE, n % cce.BLOCK_IN, cce.BLOCK_REDUCE)
        shape_yy = batch_tuple + (k // cce.BLOCK_REDUCE, m // cce.BLOCK_OUT, cce.BLOCK_OUT, cce.BLOCK_REDUCE)
        output_shape = batch_tuple + (m // cce.BLOCK_OUT, 1, n % cce.BLOCK_IN, cce.BLOCK_OUT)

    if bias == 1:
        bias_shape_nc1hwc0 = (n,)
    else:
        bias_shape_nc1hwc0 = None
    return shape_xx, shape_yy, bias_shape_nc1hwc0, output_shape, k
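# Example (assuming 16x16 blocks; defaults zZ * nZ = zN): for m=32, k=64,
# n=48, batch_tuple=(1,), no transposition and no bias, the returned
# shapes are shape_xx=(1, 2, 4, 16, 16), shape_yy=(1, 4, 3, 16, 16) and
# output_shape=(1, 3, 2, 16, 16), with bias_shape_nc1hwc0=None.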


def matmul_execute(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs):
    '''
    Davinci cores use four fractal formats: zZ, zN, nZ, nN.
    General matmul format combinations:
    left_trans: False right_trans: False: zZ * nZ = zN
    left_trans: True  right_trans: False: nN * nZ = zN
    left_trans: False right_trans: True : zZ * zN = zN
    left_trans: True  right_trans: True : nN * zN = zN

    zN * nZ = zN also needs to be supported;
    left_format specifies the left matrix data format and
    right_format specifies the right matrix data format.
    '''
    batch_tuple, m, k, n = extract_dim(shape_x, shape_y, adj_x, adj_y)
    # round m, n and k up to multiples of 16, the fractal block size
    m = (m + 15) // 16 * 16
    n = (n + 15) // 16 * 16
    k = (k + 15) // 16 * 16
    shape_xx, shape_yy, bias_shape, out_shape, k = get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias, left_format, right_format, out_format)
    mod = matmul_compile(shape_x, shape_y, bias, left_format, right_format, out_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs)
    # Generate data
    m_x, m_y, bench_mark, bias_data = matmul_data(batch_tuple, m, k, n, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format, right_format, out_format)

    # mod launch
    output = np.full(out_shape, np.nan, out_dtype)
    if bias == 0:
        output = utils.mod_launch(mod, (m_x, m_y, output), expect=bench_mark)
    elif bias == 1:
        output = utils.mod_launch(mod, (m_x, m_y, bias_data, output), expect=bench_mark)

    # compare result
    rtol, atol = get_rtol_atol("matmul", dtype)
    compare_result = compare_tensor(output, bench_mark, rtol=rtol, atol=atol, equal_nan=True)
    # compare_result = utils.result_compare(output, bench_mark, r_tol=5e-3)
    return (m_x, m_y), output, bench_mark, compare_result


def matmul_compile(shape_x, shape_y, bias, left_format, right_format, output_format, adj_x, adj_y, dtype, bias_dtype, out_dtype, kernel_name, attrs, tuning=False):
    batch_tuple, m, k, n = extract_dim(shape_x, shape_y, adj_x, adj_y)
    # round m, n and k up to multiples of 16, the fractal block size
    m = (m + 15) // 16 * 16
    n = (n + 15) // 16 * 16
    k = (k + 15) // 16 * 16
    shape_xx, shape_yy, bias_shape, out_shape, k = get_converted_shapes(m, n, k, batch_tuple, adj_x, adj_y, bias,
                                                                        left_format, right_format, output_format)
    input_shapes = [shape_xx, shape_yy, bias_shape]
    input_types = [dtype, dtype, bias_dtype]
    has_bias = (bias == 1)
    op_attrs = [out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
    if not has_bias:
        input_shapes = [shape_xx, shape_yy]
        input_types = [dtype, dtype]
        op_attrs = [None, out_dtype, left_format, right_format, output_format, adj_x, adj_y, attrs]
    return utils.op_build_test(matmul.matmul, input_shapes, input_types, op_attrs, kernel_name, attrs, tuning=tuning)
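

if __name__ == "__main__":
    # Minimal smoke-test sketch with hypothetical sizes; kernel_name and
    # attrs are placeholders, and a working Davinci build environment is
    # assumed rather than guaranteed here.
    _, output, expect, passed = matmul_execute(
        shape_x=(1, 32, 64), shape_y=(1, 64, 48), bias=0,
        left_format="zZ", right_format="nZ", out_format="zN",
        adj_x=False, adj_y=False, dtype="float16",
        bias_dtype="float16", out_dtype="float16",
        kernel_name="matmul_demo_run", attrs={})
    logging.info("matmul compare result: %s", passed)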