gather.py 5.5 KB
Newer Older
C
ckey_Dou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: gather"""
import akg.tvm
import akg
from akg import backend as cce
from akg.utils import kernel_exec as utils
from akg.utils import validation_check as vc_util

def new_alloc(ib, dtype, shape, name, scope):
    """Allocate a scoped buffer through ``ib`` and return it as a declared buffer.

    Args:
        ib: the IR builder that performs the allocation.
        dtype: element dtype of the allocation.
        shape: shape of the allocation (also used as the declared buffer shape).
        name: name given to both the allocation and the declared buffer.
        scope: memory scope string passed to both calls.

    Returns:
        A buffer created by ``akg.tvm.decl_buffer`` whose ``data`` is the
        freshly allocated variable.
    """
    storage = ib.allocate(dtype, shape, name=name, scope=scope)
    return akg.tvm.decl_buffer(shape, storage.dtype, name=name, scope=scope, data=storage)


def _elems_per_block(dtype, block_bytes=32):
    """Return how many elements of ``dtype`` fit into one burst block of ``block_bytes``.

    The bit width is parsed from the dtype string (``"float16"`` -> 16 bits,
    ``"int8"`` -> 8 bits). A dtype without a width suffix (e.g. ``"bool"``) is
    treated as 1 byte -- NOTE(review): confirm against the backend's storage size.
    """
    digits = "".join(ch for ch in str(dtype) if ch.isdigit())
    bits = int(digits) if digits else 8
    elem_bytes = max(bits // 8, 1)
    return max(block_bytes // elem_bytes, 1)


def _emit_stage_indices(ib, indices, indices_ub, gm_offset, burst):
    """Emit a ``copy_gm_to_ubuf`` of ``burst`` blocks of ``indices`` (from ``gm_offset``) into ``indices_ub``."""
    ib.emit(akg.tvm.call_extern(indices.dtype, "copy_gm_to_ubuf",
                                indices_ub.access_ptr("w"),
                                indices.access_ptr('r', offset=gm_offset),
                                0, 1, burst, 0, 0))


def _emit_gather_rows(ib, dst, data, indices, data_ub, indices_ub,
                      first_row, row_count, row_len, burst_len):
    """Emit a loop copying ``row_count`` rows selected by the indices staged in ``indices_ub``.

    For staged entry ``row`` holding index ``i``, ``data[i, :]`` is copied into
    ``data_ub`` and then written to output row ``first_row + row`` of ``dst``.
    """
    with ib.for_range(0, row_count, name='row') as row:
        reg = ib.allocate("int32", (1,), name='reg', scope=cce.scope_reg)
        # move the staged index into a scalar register so it can drive the GM offset
        ib.emit(akg.tvm.call_extern(
            indices.dtype, "reg_mov",
            akg.tvm.call_extern(reg.dtype, "reg", reg[0]),
            indices_ub.access_ptr('r', offset=row)
        ))
        gm_offset = reg[0] * row_len
        ib.emit(akg.tvm.call_extern(data.dtype, "copy_gm_to_ubuf",
                                    data_ub.access_ptr("w"),
                                    data.access_ptr('r', offset=gm_offset),
                                    0, 1, burst_len, 0, 0))
        ib.emit(akg.tvm.call_extern(dst.dtype, "copy_ubuf_to_gm",
                                    dst.access_ptr('w', offset=(first_row + row) * row_len),
                                    data_ub.access_ptr("r"),
                                    0, 1, burst_len, 0, 0))


def kernel_ir(dst, data, indices):
    """Build the IR that gathers rows of ``data`` selected by ``indices`` into ``dst``.

    Indices are staged into a UB buffer (``copy_gm_to_ubuf``) in batches of
    1024. For each staged index ``i``, the row ``data[i, :]`` is copied into a
    UB row buffer and then written (``copy_ubuf_to_gm``) to the corresponding
    row of ``dst``. A trailing partial batch handles the remaining
    ``indices.shape[0] % 1024`` indices.

    Burst lengths are counted in 32-byte blocks and are derived from each
    buffer's dtype, so non-2-byte row dtypes are copied with the right length.

    Args:
        dst: output buffer; presumably shaped (len(indices), data.shape[1]) as
            built by ``gather`` -- TODO confirm for other callers.
        data: 2-D source buffer whose rows have ``data.shape[1]`` elements.
        indices: 1-D index buffer (``gather`` restricts it to int32).

    Returns:
        The statement produced by the IR builder.

    NOTE(review): bursts move whole blocks, so when ``row_len`` is not a block
    multiple the final row write may extend past the end of ``dst``; earlier
    overlaps are overwritten by the following row.
    """
    ib = akg.tvm.ir_builder.create()

    indices_ptr = ib.buffer_ptr(indices)
    batch_size = 1024
    batch_num = indices.shape[0] // batch_size
    last_size = indices.shape[0] % batch_size
    idx_per_block = _elems_per_block(indices_ptr.dtype)
    burst_len_of_batch_size = (batch_size + idx_per_block - 1) // idx_per_block
    burst_len_of_last_size = (last_size + idx_per_block - 1) // idx_per_block

    row_len = data.shape[1]
    data_per_block = _elems_per_block(data.dtype)
    burst_len = (row_len + data_per_block - 1) // data_per_block

    # round the UB row buffer up to whole blocks so one burst never writes past it
    data_ub = new_alloc(ib, data.dtype, (burst_len * data_per_block,), "X_UB", scope=cce.scope_ubuf)
    indices_ub = new_alloc(ib, indices_ptr.dtype, (batch_size,), "Y_UB", scope=cce.scope_ubuf)

    with ib.if_scope(batch_num > 0):
        with ib.for_range(0, batch_num, name='num') as num:
            _emit_stage_indices(ib, indices, indices_ub, num * batch_size, burst_len_of_batch_size)
            _emit_gather_rows(ib, dst, data, indices, data_ub, indices_ub,
                              num * batch_size, batch_size, row_len, burst_len)

    with ib.if_scope(last_size > 0):
        _emit_stage_indices(ib, indices, indices_ub, batch_num * batch_size, burst_len_of_last_size)
        _emit_gather_rows(ib, dst, data, indices, data_ub, indices_ub,
                          batch_num * batch_size, last_size, row_len, burst_len)

    return ib.get()

@vc_util.check_input_type((list, tuple), (list, tuple), str, str, int, str, (str, type(None)))
def gather(params_shape, indices_shape, params_dtype, indices_dtype, axis, kernel_name, cce_path="./"):
    """Build a CCE kernel that gathers rows of a 2-D tensor along axis 0.

    Args:
        params_shape: shape of the source tensor; must have length 2.
        indices_shape: shape of the index tensor; must have length 1.
        params_dtype: dtype of the source tensor (any Davinci type).
        indices_dtype: dtype of the indices; must be int32.
        axis: gather axis; only 0 is accepted.
        kernel_name: name of the generated kernel.
        cce_path: location passed to ``utils.create_cce`` for the generated source.

    Returns:
        The module produced by ``akg.build``.
    """
    # argument validation, in the same order as the checks are meant to fire
    vc_util.check_shape(params_shape, length=2)
    vc_util.check_shape(indices_shape, length=1)
    vc_util.ops_dtype_check(params_dtype, vc_util.DtypeForDavinci.ALL_TYPES)
    vc_util.ops_dtype_check(indices_dtype, vc_util.DtypeForDavinci.INT32)
    vc_util.check_equal("axis", "zero", axis, 0)

    out_shape = (indices_shape[0], params_shape[1])
    params_ph = akg.tvm.placeholder(params_shape, dtype=params_dtype, name="X")
    indices_ph = akg.tvm.placeholder(indices_shape, dtype=indices_dtype, name="Y")

    def _build_ir(ins, outs):
        return kernel_ir(outs[0], ins[0], ins[1])

    result = akg.tvm.extern(out_shape, [params_ph, indices_ph], _build_ir,
                            name="res", dtype=params_dtype)
    schedule = akg.tvm.create_schedule(result.op)

    # lower and build for the cce target, then emit the generated source
    build_attrs = {"enable_multicore": False}
    with akg.build_config(add_lower_pass=cce.debug_mode(0), dump_pass_ir=True):
        module = akg.build(schedule, [params_ph, indices_ph, result], "cce",
                           name=kernel_name, attrs=build_attrs)

    generated = module.imported_modules[0].get_source()
    utils.create_cce(kernel_name, cce_path, generated)

    return module