# maxpool_ad_run.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorio import compare_tensor
import numpy as np
from akg.utils import kernel_exec as utils
from test_run import maxpool_grad_run
from akg.ops.nn.maxpool_ad import maxpool_ad
from akg.ops.nn.maxpool_ad import maxpool_ad_manual_schedule_all_max
from akg.ops.nn.maxpool_ad import maxpool_ad_no_custom_diff_manual_schedule_all_max
from akg.ops.nn.maxpool_ad import maxpool_ad_no_custom_diff_poly_all_max
from akg.utils.dsl_create import cal_pad_shapes_by_strategy
import itertools
from base import get_rtol_atol
from gen_random import random_gaussian


def maxpool_ad_run(shape, kernel, stride, pad, dtype, optimized, polyhedral=False, first_max=True, attrs=None):
    """Build and launch a max-pool gradient kernel and compare it with a NumPy reference.

    The kernel variant is selected by ``polyhedral``, ``optimized`` and
    ``first_max``. Combinations without an implementation raise ``Exception``.

    Args:
        shape: forward input shape (the reference unpacks it as N, C1, H, W, C0).
        kernel, stride, pad: pooling parameters forwarded to the op builders.
        dtype: element dtype name used for every tensor.
        optimized: True selects the hand-written ("DSL") gradient ops; False
            selects the ``*_no_custom_diff_*`` ("AD") ops.
        polyhedral: build with the polyhedral scheduler.
        first_max: True routes each window's gradient to the first maximum
            only; False routes it to every maximum.
        attrs: build attributes forwarded unchanged to the builders. If it
            contains the key ``"tuning"``, the function returns early (see Returns).

    Returns:
        When ``attrs`` contains ``"tuning"``: ``(mod, expect, (head, input, output))``
        if its value is truthy, otherwise ``mod``.
        Otherwise: ``([head, input], output, expect, passed)`` where ``passed``
        is the result of ``compare_tensor``.

    Raises:
        Exception: for variant combinations that have no implementation.
    """
    expect, head, input_data, output, forward, mask = gen_data(dtype, kernel, pad, shape, stride, first_max)

    # attrs defaults to None, so guard before the membership test.
    tuning_mode = attrs is not None and "tuning" in attrs
    tuning = attrs.get("tuning", False) if tuning_mode else False

    if polyhedral:
        if optimized:
            if first_max:
                raise Exception("ERROR: no DSL with poly support for first_max")
            raise Exception("ERROR: no DSL with poly support for all_max")
        if first_max:
            raise Exception("ERROR: no AD with poly support for first_max")
        mod = utils.op_build_test(maxpool_ad_no_custom_diff_poly_all_max, [head.shape, shape],
                                  [dtype, dtype], kernel_name="maxpool_ad_no_custom_diff_poly_all_max",
                                  op_attrs=[kernel, stride, pad], attrs=attrs, log_cce=False, dump_code=True,
                                  polyhedral=polyhedral)
        output = utils.mod_launch(mod, [head, input_data, output], expect=expect)
    elif optimized:
        if first_max:
            mod = utils.op_build_test(maxpool_ad, [head.shape, shape, forward.shape, mask.shape],
                                      [dtype, dtype, dtype, dtype], kernel_name="maxpool_ad_first_max",
                                      op_attrs=[kernel, stride, pad], attrs=attrs, log_cce=False, dump_code=True,
                                      polyhedral=polyhedral)
            output = utils.mod_launch(mod, [head, input_data, forward, mask, output], expect=expect)
        else:
            mod = maxpool_ad_manual_schedule_all_max(shape, kernel, stride, pad, dtype, attrs=attrs,
                                                     polyhedral=polyhedral)
            output = utils.mod_launch(mod, [head, input_data, forward, output], expect=expect)
    else:
        if first_max:
            raise Exception("ERROR: no AD with mansch support for first_max")
        mod = utils.op_build_test(maxpool_ad_no_custom_diff_manual_schedule_all_max, [head.shape, shape],
                                  [dtype, dtype], kernel_name="maxpool_ad_no_custom_diff_manual_schedule_all_max",
                                  op_attrs=[kernel, stride, pad], attrs=attrs, log_cce=False, dump_code=True,
                                  polyhedral=polyhedral)
        output = utils.mod_launch(mod, [head, input_data, output], expect=expect)

    if tuning_mode:
        if tuning:
            return mod, expect, (head, input_data, output)
        return mod

    rtol, atol = get_rtol_atol("maxpool_grad", dtype)
    return [head, input_data], output, expect, compare_tensor(output, expect, rtol=rtol, atol=atol, equal_nan=True)

def benchmark(x, y, dy, kernel, stride, pad, behaviour=0):
    """NumPy reference for the max-pool gradient.

    Args:
        x: forward input, unpacked as (N, C1, H, W, C0).
        y: forward max-pool output, unpacked as (_, _, yH, yW, _). Its values
            are compared against the padded input windows.
        dy: incoming gradient, indexed like ``y``.
        kernel: (kernel_h, kernel_w) pooling window.
        stride: (stride_h, stride_w).
        pad: padding spec forwarded to ``cal_pad_shapes_by_strategy``.
        behaviour: 0 routes each window's gradient only to the first element
            (row-major over the window) equal to the window's output value, and
            records it in ``mask``. 1 routes it to every equal element;
            ``mask`` stays all zeros.

    Returns:
        (dx, mask): ``dx`` has the dtype of ``x`` with the padding cropped off.
        ``mask`` is a float64 array of shape (N, C1, kernel_h, kernel_w, yH, yW, C0).

    Raises:
        ValueError: if ``behaviour`` is neither 0 nor 1 (previously this
            silently produced an all-zero gradient).
    """
    if behaviour not in (0, 1):
        raise ValueError("unsupported behaviour %r, expected 0 or 1" % (behaviour,))

    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    [pad_h_head, pad_h_tail, pad_w_head, pad_w_tail], _ = cal_pad_shapes_by_strategy(x.shape, kernel, stride, pad)
    N, C1, H, W, C0 = x.shape
    pad_shape = (N, C1, H + pad_h_tail + pad_h_head, W + pad_w_tail + pad_w_head, C0)

    # Padding is filled with 0. NOTE(review): this assumes a padded 0 never
    # equals a window's output value (true for the positive inputs gen_data uses).
    padx = np.full(pad_shape, 0, dtype=x.dtype)
    padx[:, :, pad_h_head:(pad_h_head + H), pad_w_head:(pad_w_head + W), :] = x

    dx = np.zeros(padx.shape, dtype=x.dtype)
    _, _, yH, yW, _ = y.shape
    mask = np.zeros((N, C1, kernel_h, kernel_w, yH, yW, C0))
    first_only = behaviour == 0

    # np.ndindex walks indices in C order, i.e. the same (n, c1, yh, yw, c0)
    # order as nested loops, so the float accumulation order into dx is unchanged.
    for n, c1, yh, yw, c0 in np.ndindex(N, C1, yH, yW, C0):
        out_val = y[n, c1, yh, yw, c0]
        grad = dy[n, c1, yh, yw, c0]
        for kh, kw in itertools.product(range(kernel_h), range(kernel_w)):
            h = yh * stride_h + kh
            w = yw * stride_w + kw
            if padx[n, c1, h, w, c0] == out_val:
                dx[n, c1, h, w, c0] += grad
                if first_only:
                    mask[n, c1, kh, kw, yh, yw, c0] = 1.0
                    break

    return dx[:, :, pad_h_head:(pad_h_head + H), pad_w_head:(pad_w_head + W), :], mask


def gen_data(dtype, kernel, pad, shape, stride, first_max):
    """Generate random inputs and the reference gradient for maxpool_ad_run.

    Args:
        dtype: dtype every returned array is cast to.
        kernel, stride, pad: pooling parameters.
        shape: forward input shape.
        first_max: True selects the first-max reference (behaviour 0 of
            ``benchmark``); False selects the all-max reference (behaviour 1).

    Returns:
        (expect, head, input, output, y, mask):
        - ``expect``: reference input gradient.
        - ``head``: random incoming gradient.
        - ``input``: random forward input.
        - ``output``: zero-filled buffer shaped like ``expect``.
        - ``y``: forward max-pool result.
        - ``mask``: first-max mask, all zeros when ``first_max`` is False.
        All are of ``dtype``.
    """
    behaviour = 0 if first_max else 1
    # Random draws stay in the original order (input first, then head) so that
    # seeded runs reproduce the same data.
    input_data = random_gaussian(shape, miu=1, sigma=0.1).astype(dtype)
    y = maxpool_grad_run.maxpool_benchmark(input_data, kernel, stride, pad).astype(dtype)
    head = random_gaussian(y.shape, miu=1, sigma=0.1).astype(dtype)
    expect, mask = benchmark(input_data, y, head, kernel, stride, pad, behaviour)
    expect = expect.astype(dtype)
    mask = mask.astype(dtype)
    output = np.full(expect.shape, 0.0, dtype)
    return expect, head, input_data, output, y, mask