cudnn_conv_bias_activation.cpp 12.7 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/oprs/general.h"

#include "./algo.h"

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
        const SizeArgs& args) const {
27 28
    if (args.filter_meta.format != Param::Format::NCHW &&
        args.filter_meta.format != Param::Format::NHWC) {
M
Megvii Engine Team 已提交
29
        if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
30 31
            return false;
        }
32
    }
33 34
    if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
         args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
35 36
        args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
        return false;
37 38 39
    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm)
        return false;
40 41 42 43
    if (args.src_layout->dtype == args.filter_layout->dtype &&
        args.src_layout->dtype == dtype::BFloat16()) {
        return false;
    }
44

45
    if (args.bias_layout->ndim == 0 ||
M
Megvii Engine Team 已提交
46
        !check_bias_share_in_channel(*(args.bias_layout), args.opr->param().format)) {
47
        return false;
48
    }
49
    auto&& param = args.opr->param();
50 51 52 53 54 55 56 57 58 59 60

#if (CUDNN_MAJOR == 8 && CUDNN_MINOR < 2)
    if (m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM &&
        param.format == param::ConvBias::Format::NCHW4 &&
        args.filter_meta.group * args.filter_meta.ocpg > 256 &&
        args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
        args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS8) {
        return false;
    }
#endif

61 62 63 64 65 66 67 68 69
    // FIXME: cudnn cannot handle the case when the initial value of dst tensor
    // contains nan and beta is zero, because the result of 0.f * nan is still
    // nan
    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
        args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
        param.format == param::ConvBias::Format::NCHW) {
        return false;
    }

70 71 72 73 74 75
    if (args.src_layout->dtype.enumv() == DTypeEnum::Float16 &&
        args.dst_layout->dtype.enumv() == DTypeEnum::Float16 &&
        param.format == param::ConvBias::Format::NHWC) {
        return false;
    }

76 77 78 79 80 81
#if CUDNN_MAJOR < 8
    if (m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
        param.format == param::ConvBias::Format::NCHW4_NCHW)
        return false;
#endif
    if (param.format == param::ConvBias::Format::NCHW4_NCHW32 ||
82 83
        param.format == param::ConvBias::Format::NCHW32_NCHW4)
        return false;
84 85 86 87 88 89 90 91 92 93 94 95
    if (param.format == param::ConvBias::Format::NCHW &&
        (param.dilate_h != 1 || param.dilate_w != 1) &&
        m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
        auto&& device_prop = current_device_prop();
        // Dilated convbias in NCHW format produces wrong result on Pascal
        // Architecture, so we disable the algo here.
        if (device_prop.major == 6) {
            return false;
        }
    }

    if (param.format == param::ConvBias::Format::NCHW8 ||
96
        param.format == param::ConvBias::Format::NCHW64 ||
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
        param.format == param::ConvBias::Format::CHWN4)
        return false;
    if (param.format == param::ConvBias::Format::NCHW32) {
        auto&& filter_meta = args.filter_meta;
        // NCHW32 layout only support group = 1
        if (filter_meta.group != 1)
            return false;
        // The data type (CUDNN_DATA_INT8x32) can only be used with algo
        // "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", for details, see
        // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
        if (m_cudnn_enum != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM)
            return false;
        // check cudnn version
        if (CUDNN_VERSION < 7500)
            return false;
        // sm version
        auto&& device_prop = current_device_prop();
M
Megvii Engine Team 已提交
114
        if (device_prop.major < 7 || (device_prop.major == 7 && device_prop.minor < 5))
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
            return false;
    }

    CUDNNForwardDescs D;

    if (CUDNN_VERSION < 7401)
        return false;

    args.init_conv_bias_desc(D);
    switch (args.nonlinear_mode) {
        case param::ConvBias::NonlineMode::RELU:
            break;
        case param::ConvBias::NonlineMode::SIGMOID:
            // forbits sigmoid for quantized
            if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
                return false;
            MEGDNN_FALLTHRU  // XXX: why?
M
Megvii Engine Team 已提交
132 133 134 135
                    case param::ConvBias::NonlineMode::IDENTITY
                    : if (args.src_layout->dtype.category() ==
                          DTypeCategory::QUANTIZED) break;
            if (m_cudnn_enum != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
136 137 138 139 140 141 142
                // cudnn require algo to
                // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
                // when activation if IDENTITY
                return false;
            }
            break;
        case param::ConvBias::NonlineMode::H_SWISH:
143 144
            if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
                break;
145 146
            return false;
        default:
M
Megvii Engine Team 已提交
147
            megdnn_throw("unsupported NonlineMode");
148 149
    }
    size_t workspace_size;
150
    auto status = cudnnGetConvolutionForwardWorkspaceSize(
151
            args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
M
Megvii Engine Team 已提交
152
            D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size);
153 154 155 156 157 158 159 160 161
    return status == CUDNN_STATUS_SUCCESS;
}

size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
        const SizeArgs& args) const {
    CUDNNForwardDescs D;

    args.init_conv_bias_desc(D);
    size_t workspace_size;
162
    auto status = cudnnGetConvolutionForwardWorkspaceSize(
163
            args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
M
Megvii Engine Team 已提交
164 165 166 167 168
            D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size);
    megdnn_assert(
            status == CUDNN_STATUS_SUCCESS,
            "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status),
            args.to_string().c_str());
169 170 171 172 173 174 175 176 177 178 179 180
    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
        // cudnn require bias to be float when executing CONFIG_INT
        // convert bias to float if bias is not float at first
        workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
    }
    return workspace_size;
}

void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
        const ExecArgs& args) const {
#if CUDNN_MAJOR < 7
M
Megvii Engine Team 已提交
181
    megdnn_throw("ConvBias require cudnn 7.0 or higher");
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
#else
    megdnn_assert(cudnnGetVersion() >= 7401);
    CUDNNForwardDescs D;
    args.init_conv_bias_desc(D);
    float alpha = 1.0f, beta = 0.0f;
    if (args.z_layout->ndim > 0)
        beta = 1.0f;

    auto get_scale = [](const DType& dtype) -> float {
        megdnn_assert(dtype.category() == DTypeCategory::QUANTIZED);
        switch (dtype.enumv()) {
#define cb(_dt)                  \
    case DTypeTrait<_dt>::enumv: \
        return dtype.param<_dt>().scale;
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
            default:
                megdnn_assert_internal(0);
        }
    };

M
Megvii Engine Team 已提交
203
    auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype,
204 205 206
         dst_dtype = args.dst_layout->dtype;
    megdnn_assert(
            (src_dtype.category() == dst_dtype.category()) ||
207
            (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
208 209
             dst_dtype.enumv() == DTypeEnum::Float32));
    megdnn_assert(src_dtype.category() == filter_dtype.category());
210 211 212 213

    if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) {
        auto expected_bias_scale = get_scale(args.src_layout->dtype) *
                                   get_scale(args.filter_layout->dtype);
214 215 216 217 218
        alpha = expected_bias_scale;
        if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED)
            alpha /= get_scale(args.dst_layout->dtype);
        if (args.z_layout->ndim > 0 &&
            args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
M
Megvii Engine Team 已提交
219
            beta = get_scale(args.z_layout->dtype) / get_scale(args.dst_layout->dtype);
220 221
        }
        if (args.bias_layout->dtype.category() == DTypeCategory::QUANTIZED) {
M
Megvii Engine Team 已提交
222 223 224
            megdnn_assert(
                    fabs(expected_bias_scale - get_scale(args.bias_layout->dtype)) <
                    1e-4);
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
        }
    }

    auto workspace_ptr = args.workspace.raw_ptr;
    auto workspace_size = args.workspace.size;
    auto bias_ptr = args.bias_tensor->raw_ptr;
    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
        auto cvt = args.handle->create_operator<TypeCvt>();
        auto float_bias_layout = *args.bias_layout;
        auto converted_bias_layout = *args.bias_layout;
        converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
        float_bias_layout.dtype = dtype::Float32();
        auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
        megdnn_assert(args.workspace.size >= bias_size_in_bytes);
M
Megvii Engine Team 已提交
240 241 242
        cvt->exec(
                {args.bias_tensor->raw_ptr, converted_bias_layout},
                TensorND{workspace_ptr, float_bias_layout});
243 244 245 246 247 248 249 250 251 252 253

        bias_ptr = workspace_ptr;
        workspace_ptr += bias_size_in_bytes;
        workspace_size -= bias_size_in_bytes;
    }

    cudnnStatus_t status;
    if (args.z_layout->ndim == 0) {
        status = cudnnConvolutionBiasActivationForward(
                args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
                args.src_tensor->raw_ptr, D.filter_desc.desc,
M
Megvii Engine Team 已提交
254 255 256 257
                args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum,
                workspace_ptr, workspace_size, &beta, D.dst_desc.desc,
                args.dst_tensor->raw_ptr, D.bias_desc.desc, bias_ptr,
                D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr);
258 259 260 261
    } else {
        status = cudnnConvolutionBiasActivationForward(
                args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
                args.src_tensor->raw_ptr, D.filter_desc.desc,
M
Megvii Engine Team 已提交
262 263 264 265
                args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum,
                workspace_ptr, workspace_size, &beta, D.z_desc.desc,
                args.z_tensor->raw_ptr, D.bias_desc.desc, bias_ptr,
                D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr);
266 267
    }

M
Megvii Engine Team 已提交
268 269 270
    megdnn_assert(
            status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s, algo %s",
            cudnnGetErrorString(status), args.to_string().c_str(), name());
271 272 273 274 275
    // Noline
    switch (args.nonlinear_mode) {
        case param::ConvBias::NonlineMode::RELU:
            break;
        case param::ConvBias::NonlineMode::SIGMOID: {
M
Megvii Engine Team 已提交
276 277
            megdnn_assert(
                    args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED);
278 279 280 281 282 283 284
            auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
            elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
            elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
            break;
        }
        case param::ConvBias::NonlineMode::IDENTITY:
            break;
285
        case param::ConvBias::NonlineMode::H_SWISH: {
M
Megvii Engine Team 已提交
286 287 288 289
            megdnn_assert(
                    args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED ||
                    (args.dst_layout->dtype.category() == DTypeCategory::FLOAT &&
                     args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW));
290
            if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
M
Megvii Engine Team 已提交
291 292
                auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
                elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
293 294
                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
            } else {
M
Megvii Engine Team 已提交
295
                auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
296 297 298
                elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
            }
299 300
            break;
        }
301
        default:
M
Megvii Engine Team 已提交
302
            megdnn_throw("unsupported NonlineMode");
303 304 305 306 307
    }
#endif
}

// vim: syntax=cpp.doxygen