/**
 * \file dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./algo.h"
#include "src/cuda/utils.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"

using namespace megdnn;
using namespace cuda;

bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
        const SizeArgs& args) const {
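    // an explicit bias tensor is required by this algorithm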
    if (args.bias_layout->ndim <= 0)
        return false;

    using Param = param::ConvBias;
    using Format = Param::Format;
    using Sparse = Param::Sparse;
    using Mode = Param::Mode;
    bool available = true;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
                                                param.format))
        return false;
    if (param.format != Format::NCHW4)
        return false;
    UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                 param);
    // TODO support group conv
    available &= param.sparse == Sparse::DENSE;
    // mode must be cross correlation
    available &= param.mode == Mode::CROSS_CORRELATION;
    // check data type
    auto src_dtype = args.src_layout->dtype,
         filter_dtype = args.filter_layout->dtype,
         bias_dtype = args.bias_layout->dtype,
         dst_dtype = args.dst_layout->dtype;
    available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                  filter_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                  bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
                  dst_dtype.enumv() == DTypeEnum::QuantizedS8);
    // TODO: support dilation
    available &= dh == 1 && dw == 1;
    // only supported on sm_61 or later; the platform should have fast native
    // int8 support
    available &= is_compute_capability_required(6, 1);
    // FIXME: filters with fh * fw > 49 are not supported yet
    available &= fh * fw <= 49;
    return available;
}

WorkspaceBundle
ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_bundle(
        dt_byte* raw_ptr, const SizeArgs& args) const {
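    // the workspace only needs to hold the reordered (CHWN4) filter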
    size_t ws_filter = args.filter_layout->span().dist_byte();
    return WorkspaceBundle{raw_ptr, {ws_filter}};
}

size_t
ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec(
        const ExecArgs& args) const {
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                 param);
    auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
    auto ws_filter = ws.get(0);
    auto&& stream = cuda_stream(args.opr->handle());

    // reformat filter from nchw4 to chwn4
    {
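        // Each group of 4 int8 channel values is viewed as one int32, so the
        // filter is a (co, ci/4 * fh * fw) int32 matrix; the strided relayout
        // below writes its transpose into the workspace, producing the CHWN4
        // ordering consumed by the kernel.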
        TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()};
        src.init_contiguous_stride();
        TensorLayout dst = src;
        dst.stride[0] = 1, dst.stride[1] = dst[0];
        TensorND ts_src, ts_dst;
        ts_src.raw_ptr = args.filter_tensor->raw_ptr;
        ts_src.layout = src;
        ts_dst.raw_ptr = ws_filter;
        ts_dst.layout = dst;
        auto&& transpose =
                args.opr->handle()->create_operator<RelayoutForward>();
        transpose->exec(ts_src, ts_dst);
    }

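    // pack the convolution geometry (batch, channels, spatial sizes, padding,
    // stride and filter shape) into the parameter struct passed to the
    // cutlass wrapper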
    convolution::ConvParam kern_param;
    kern_param.n = n, kern_param.co = co, kern_param.ci = ci,
    kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho,
    kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw,
    kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh,
    kern_param.fw = fw;

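    // Fold the quantization scales into the epilogue coefficients: alpha
    // scales the int32 convolution accumulator, beta the bias, and gamma the
    // optional residual input z, so the result is written directly in the
    // destination's quantized domain.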
    float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale,
          filter_scale =
                  args.filter_layout->dtype.param<dtype::QuantizedS8>().scale,
          bias_scale =
                  args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
          dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale;
    float alpha = src_scale * filter_scale / dst_scale,
          beta = bias_scale / dst_scale;
    int8_t* z_dev_ptr = nullptr;
    float gamma = 0.0;
    if (args.z_layout->ndim > 0) {
        z_dev_ptr = args.z_tensor->compatible_ptr<int8_t>();
        float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale;
        gamma = z_scale / dst_scale;
    }
    uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode);
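    // Dispatch to the cutlass implicit-GEMM wrapper: 1x1 filters take the
    // specialized path (template argument false), all other filter sizes the
    // general path; the threadblock and warp tile shapes come from the tuned
    // m_algo_param.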
    if (fh == 1 && fw == 1) {
        cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<false>(
                args.src_tensor->compatible_ptr<int8_t>(),
                reinterpret_cast<int8_t*>(ws_filter),
                args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
                args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
                nonlinear_mode, alpha, beta, gamma, dst_scale,
                cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
                                           m_algo_param.threadblock_n,
                                           m_algo_param.threadblock_k},
                cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                           m_algo_param.warp_n,
                                           m_algo_param.warp_k},
                stream);
    } else {
        cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4<true>(
                args.src_tensor->compatible_ptr<int8_t>(),
                reinterpret_cast<int8_t*>(ws_filter),
                args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr,
                args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param,
                nonlinear_mode, alpha, beta, gamma, dst_scale,
                cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m,
                                           m_algo_param.threadblock_n,
                                           m_algo_param.threadblock_k},
                cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                           m_algo_param.warp_n,
                                           m_algo_param.warp_k},
                stream);
    }
}

// vim: syntax=cpp.doxygen