/**
 * \file dnn/src/cuda/conv_bias/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"

#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

#include "src/common/conv_bias.h"
#include "src/common/algo_chooser.h"

#include "src/cuda/cudnn_with_check.h"

namespace megdnn {
namespace cuda {

//! dispatch to the algorithm selected by get_algorithm() after validating the
//! layouts and the provided workspace
void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                               _megdnn_tensor_in bias, _megdnn_tensor_in z,
                               _megdnn_tensor_out dst,
                               const PreprocessedFilter* preprocessed_filter,
                               _megdnn_workspace workspace) {
    check_exec_allow_noncontiguous(src.layout, filter.layout, bias.layout,
                                   z.layout, dst.layout, workspace.size,
                                   preprocessed_filter);
    AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace,
                            preprocessed_filter);
    auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout,
                              z.layout, dst.layout);
    algo->check_workspace(args, workspace).exec(args);
}

std::vector<ConvBiasForward::Algorithm*>
ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src,
                                        const TensorLayout& filter,
                                        const TensorLayout& bias,
                                        const TensorLayout& z,
                                        const TensorLayout& dst) {
    return megdnn::get_all_algorithms<ConvBiasForwardImpl>(
            {this, src, filter, bias, z, dst});
}

ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    using namespace conv_bias;
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
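    //! if the dst dtype differs from the bias dtype (e.g. a quantized conv
    //! whose bias is int32 but whose output is int8), re-deduce the natural
    //! convolution output dtype from src and filter, so that the pure-conv
    //! algorithms below are queried with a consistent dst layout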
    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype,
                                            args.filter_layout->dtype,
                                            dst_layout.dtype);
    }
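    //! conv_args is a copy of args whose dst_layout may later be redirected
    //! to the re-deduced layout when querying pure-conv algorithms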
    auto conv_args = args;

    //! translate a cudnnConvolutionFwdAlgo_t suggested by cuDNN into the
    //! corresponding MegDNN algorithm object
    auto cudnn_conv_bias_act_from_enum_wrapper =
            [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
        return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo);
    };

    auto cudnn_conv_from_enum_wrapper =
            [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
        return sm_algo_pack.cudnn_conv_from_enum(algo);
    };

    //! ask cuDNN for its own ranked algorithm suggestions (a full candidate
    //! list on cuDNN >= 7, a single pick on older versions) and return the
    //! first one satisfying the attribute requirements and workspace limit
    auto get_cudnn_algo =
            [this, &conv_args, &args, workspace_limit_in_bytes, positive_attr,
             negative_attr](
                    const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>&
                            cb) -> AlgoBase* {
        auto cudnn_handle = cuda::cudnn_handle(this->handle());
        CUDNNForwardDescs desc;
        conv_args.init_conv_desc(desc);
#if CUDNN_MAJOR >= 7
        auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
        int max_count = 0;
        cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
                                                                 &max_count));
        SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
        int ret_count = 0;
        cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
                &ret_count, algo_perf.data()));
        for (int i = 0; i < ret_count; ++i) {
            auto conv_bias_algo = cb(algo_perf[i].algo);
            if (conv_bias_algo->is_available_attribute(
                        args, positive_attr, negative_attr,
                        workspace_limit_in_bytes)) {
                return conv_bias_algo;
            }
        }
#else
        cudnnConvolutionFwdAlgo_t algo;
        cudnn_check(cudnnGetConvolutionForwardAlgorithm(
                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                desc.conv_desc.conv_desc, desc.dst_desc.desc,
                CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                workspace_limit_in_bytes, &algo));
        auto conv_bias_algo = cb(algo);
        if (conv_bias_algo->is_available_attribute(args, positive_attr,
                                                   negative_attr,
                                                   workspace_limit_in_bytes))
            return conv_bias_algo;
#endif
        return nullptr;
    };

    //! a 1x1 convolution can be computed as a batched matrix multiplication
    auto get_1x1_algo = [workspace_limit_in_bytes, positive_attr,
                         negative_attr](const AlgoBase::SizeArgs& size_arg)
            -> ConvBiasForwardImpl::AlgoBase* {
        if (sm_algo_pack.batched_matmul.is_available_attribute(
                    size_arg, positive_attr, negative_attr,
                    workspace_limit_in_bytes)) {
            return &sm_algo_pack.batched_matmul;
        }
        return nullptr;
    };

    //! detect channel-wise convolution: for NCHW the group count equals the
    //! input channel count, while NCHW4/NCHW32 pack 4/32 channels per slot of
    //! the channel dimension
    const bool is_chanwise =
            (args.filter_meta.format == Param::Format::NCHW &&
             args.filter_meta.group == src[1]) ||
            (args.filter_meta.format == Param::Format::NCHW4 &&
             args.filter_meta.group == src[1] * 4) ||
            (args.filter_meta.format == Param::Format::NCHW32 &&
             args.filter_meta.group == src[1] * 32);
    // prefer our special chanwise impl: group convolution in cuDNN versions
    // older than v7.5.0 is still slower than our implementation in many
    // channel-wise cases
    const bool slow_cudnn_chanwise_impl =
            CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5);
    //! choose CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM by default for large
    //! images
    const int hw_size = src[2] * src[3];
    //! choose our dnn impl when stride != 1; this may need recalibration for
    //! different cuDNN versions
    const bool prefer_dnn_chanwise =
            slow_cudnn_chanwise_impl || args.filter_meta.stride[0] != 1 ||
            args.filter_meta.stride[1] != 1 || hw_size < 512;
    //! avoid bad cases in cudnn by checking our dnn chanwise impl first
    if (is_chanwise) {
        if (prefer_dnn_chanwise) {
            if (sm_algo_pack.chanwise.is_available_attribute(
                        args, positive_attr, negative_attr,
                        workspace_limit_in_bytes))
                return &sm_algo_pack.chanwise;
            if (sm_algo_pack.chanwise8x8x32.is_available_attribute(
                        args, positive_attr, negative_attr,
                        workspace_limit_in_bytes))
                return &sm_algo_pack.chanwise8x8x32;
        } else {
            conv_args.dst_layout = &dst_layout;
            if (is_cudnn_supported(conv_args)) {
                if (auto algo = get_cudnn_algo(cudnn_conv_from_enum_wrapper)) {
                    return algo;
                }
            }
        }
    }

    //! Prefer CUDNN CONVBIAS: the fused conv+bias+activation kernels, when
    //! available, avoid launching separate bias/activation kernels.
    bool cudnn_conv_bias_act_supported = false;
    for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) {
        if (algo.is_available_attribute(args, positive_attr, negative_attr,
                                        workspace_limit_in_bytes)) {
            cudnn_conv_bias_act_supported = true;
            break;
        }
    }

    if (cudnn_conv_bias_act_supported) {
        if (auto algo = get_cudnn_algo(cudnn_conv_bias_act_from_enum_wrapper))
            return algo;
    }

    //! fall back to a plain cuDNN convolution, querying with the re-deduced
    //! conv output layout; bias and activation are then applied separately
    conv_args.dst_layout = &dst_layout;
    if (is_cudnn_supported(conv_args)) {
        if (auto algo = get_cudnn_algo(cudnn_conv_from_enum_wrapper))
            return algo;
    }

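    //! try lowering a 1x1 convolution to a batched matmul before the generic
    //! group and fallback paths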
    if (auto algo = get_1x1_algo(args)) {
        return algo;
    }

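    //! grouped convolution goes through the generic group algorithm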
    if (args.filter_meta.group > 1 &&
        sm_algo_pack.group.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.group;
    }

    //! fallback for quantized int8 convolutions given directly in NCHW format
    if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        return &sm_algo_pack.fallback_nchw_qs8;
    }

    //! nothing matched above: fall back to the remaining MegDNN algorithms,
    //! with a dedicated list for BFloat16 inputs
    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
        return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
                "cuda convbias fwd", positive_attr, negative_attr);
    } else {
        return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>(
                sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
                "cuda convbias fwd", positive_attr, negative_attr);
    }
}

const char* ConvBiasForwardImpl::get_algorithm_set_name() const {
    return "CONV_BIAS_CUDA";
}

size_t ConvBiasForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    AlgoBase::SizeArgs args{
            this, src, filter, bias, z, dst, preprocessed_filter};
    return get_algorithm(this, src, filter, bias, z, dst)
            ->get_workspace_in_bytes(args);
}

size_t ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst) {
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
    return get_algorithm(this, src, filter, bias, z, dst)
            ->get_preprocess_workspace_in_bytes(args);
}

SmallVector<TensorLayout>
ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst) {
    AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
    return get_algorithm(this, src, filter, bias, z, dst)
            ->deduce_preprocessed_filter_layout(args);
}

void ConvBiasForwardImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        _megdnn_tensor_in bias, const TensorLayout& z_layout,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
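    //! preprocessing only needs the layouts of src/z/dst, so wrap them in
    //! TensorNDs with null data pointers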
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout},
            z{nullptr, z_layout};
    AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace,
                            preprocessed_filter);
    auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout,
                              z.layout, dst.layout);
    return algo->exec_preprocess(args);
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen