/**
 * \file dnn/src/cuda/fake_quant/kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"

#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif

namespace megdnn {
namespace cuda {

template <typename ctype>
struct FakeQuantKernOp {
    ctype* input;
    ctype* output;
    ctype qmin, qmax;

    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
        // quantize: project onto the integer grid and clamp to [qmin, qmax]
        ctype x = round(input[idx] / scale) + zero_point;
        x = fmaxf(fminf(x, qmax), qmin);
        // dequantize back to the original value range
        output[idx] = (x - zero_point) * scale;
    }

#if MEGDNN_CC_HOST
    FakeQuantKernOp(const TensorND& input, const TensorND& output,
                    const FakeQuant::Param& param)
            : input{input.ptr<ctype>()},
              output{output.ptr<ctype>()},
              qmin(param.qmin),
              qmax(param.qmax) {}
#endif
};
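
// Usage sketch (illustrative only, not part of the original header): these
// functors are meant to be driven by the elemwise machinery from
// elemwise_helper.cuh. Assuming a run_elemwise<Op, ctype, arity>(param, stream,
// op) entry point and an arity-2 ElemwiseOpParamN carrying the broadcast
// scale / zero_point operands, a forward launch could look like:
//
//     FakeQuantKernOp<dt_float32> op(input_tensor, output_tensor, opr_param);
//     run_elemwise<FakeQuantKernOp<dt_float32>, dt_float32, 2>(
//             ele_param, stream, op);
//
// input_tensor, output_tensor, opr_param and ele_param are hypothetical
// caller-side variables.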

template <typename ctype>
struct FakeQuantBwdKernOp {
    ctype* diff;
    ctype* input;
    ctype* grad;
    ctype qmin, qmax;

    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
        // straight-through estimator: propagate the gradient only where the
        // quantized value stays inside [qmin, qmax], zero it elsewhere
        ctype x = round(input[idx] / scale) + zero_point;
        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
    }

#if MEGDNN_CC_HOST
    FakeQuantBwdKernOp(const TensorND& diff, const TensorND& input,
                       const TensorND& grad, const FakeQuant::Param& param)
            : diff{diff.ptr<ctype>()},
              input{input.ptr<ctype>()},
              grad{grad.ptr<ctype>()},
              qmin(param.qmin),
              qmax(param.qmax) {}
#endif
};

template <typename ctype>
struct FakeQuantKernOpNonContig {
    ctype qmin;
    ctype qmax;

    __device__ void operator()(uint32_t, ctype& output, ctype input,
                               ctype scale, ctype zero_point) {
        ctype x = round(input / scale) + zero_point;
        x = fmaxf(fminf(x, qmax), qmin);
        output = (x - zero_point) * scale;
    }

#if MEGDNN_CC_HOST
    FakeQuantKernOpNonContig(const FakeQuant::Param& param)
            : qmin(param.qmin), qmax(param.qmax) {}
#endif
};

template <typename ctype>
struct FakeQuantBwdKernOpNonContig {
    ctype qmin;
    ctype qmax;

    __device__ void operator()(uint32_t, ctype& grad, ctype diff, ctype input,
                               ctype scale, ctype zero_point) {
        ctype x = round(input / scale) + zero_point;
        grad = x <= qmax && x >= qmin ? diff : 0.0;
    }

#if MEGDNN_CC_HOST
    FakeQuantBwdKernOpNonContig(const FakeQuant::Param& param)
            : qmin(param.qmin), qmax(param.qmax) {}
#endif
};
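
// The non-contiguous variants above keep only qmin / qmax and receive every
// tensor operand through the elemwise param, so (under the same assumptions as
// the sketch above) the forward op would be instantiated with arity 4
// (output, input, scale, zero_point) and the backward op with arity 5
// (grad, diff, input, scale, zero_point), e.g.:
//
//     FakeQuantKernOpNonContig<dt_float32> op(opr_param);
//     run_elemwise<FakeQuantKernOpNonContig<dt_float32>, dt_float32, 4>(
//             ele_param_4, stream, op);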

}  // namespace cuda
}  // namespace megdnn