# col2im_compute.py
import akg
import tvm
from akg import backend as cce
from akg.utils import kernel_exec as utils
from test_op.col2im import intrin_col2im


def col2im_manual_schedule(shape, kernel, stride, pad, dtype, output_H_W, polyhedral=True, attrs=None):
    """
    Build a col2im kernel using a manual schedule.

    Args:
        shape (Union[list, tuple]): seven int numbers describing the input,
            unpacked as (N, C1, KH, KW, OH, OW, C0).
        kernel (Union[list, tuple]): two int numbers for the sliding window's size (height, width).
        stride (Union[list, tuple]): two int numbers for the sliding window's stride (height, width).
        pad (Union[list, tuple]): four int numbers for padding's sizes: top, bottom, left, and right.
        dtype (str): parameters' type.
        output_H_W (Union[list, tuple]): two int numbers for the output's height and width.
        polyhedral (bool): forwarded to ``akg.build``; default value is True.
        attrs (dict): forwarded to ``akg.build``.

    Returns:
        The module returned by ``akg.build`` for the scheduled operation.

    Raises:
        AssertionError: if ``output_H_W`` does not equal
            ``(O - 1) * stride + kernel - (pad_before + pad_after)`` along
            the height or the width.

    Note:
        As a side effect, the source of the first imported module of the build
        result is passed to ``utils.create_code`` with the path "./".
    """

    N, C1, KH, KW, OH, OW, C0 = shape
    H, W = output_H_W
    output_shape = (N, C1, H, W, C0)
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_t, pad_b, pad_l, pad_r = pad

    # Validate before touching the backend so that inconsistent arguments fail fast.
    assert H == (OH - 1) * stride_h + kernel_h - (pad_t + pad_b), "Height of input and output do not match"
    assert W == (OW - 1) * stride_w + kernel_w - (pad_l + pad_r), "Width of input and output do not match"

    # Single source of truth for the name used both for the build and the dumped source.
    kernel_name = "col2im_manual_schedule"

    col2im = intrin_col2im(shape, output_shape, kernel, stride, pad, dtype)

    # tensor for the input data
    data = tvm.placeholder(shape, dtype, name="input_data")

    # NOTE(review): this modulo indexing looks like a stand-in access pattern that
    # only ties every output element to the input; the stage is tensorized with
    # the col2im intrinsic below -- confirm against intrin_col2im.
    res = tvm.compute(
        output_shape,
        lambda b, c1, h, w, c0:
            data(b, c1, h % KH, w % KW, h % OH, w % OW, c0),
        name="col2im_intrinsic"
    )

    # schedule for the col2im operation
    s = tvm.create_schedule([res.op])

    # Stage both the result and the input through the "local.UB" scope.
    res_ub = s.cache_write(res, "local.UB")
    data_ub = s.cache_read(data, "local.UB", [res_ub])

    # Only the second output axis (c1) is needed as the attachment point.
    c1 = res.op.axis[1]

    s[data_ub].compute_at(s[res], c1)
    s[res_ub].compute_at(s[res], c1)

    # Replace the cached-write stage, from its outermost axis, with the intrinsic.
    s[res_ub].tensorize(res_ub.op.axis[0], col2im)

    with akg.build_config(add_lower_pass=cce.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, [data, res], "cce", name=kernel_name, attrs=attrs, polyhedral=polyhedral)
        source_code = mod.imported_modules[0].get_source()
        utils.create_code(kernel_name, "./", source_code)
    return mod