/**
 * \file dnn/src/cuda/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/cuda/convolution/opr_impl.h"
#include "megdnn/dtype.h"
#include "src/common/algo_chooser.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/convolution/forward/algos.h"
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/backward_filter/algo.h"
#include "src/cuda/conv_bias/opr_impl.h"

#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

#define TO_STRING2(v) #v
#define TO_STRING(v) TO_STRING2(v)
#define CUDNN_VERSION_STR  \
    TO_STRING(CUDNN_MAJOR) \
    "." TO_STRING(CUDNN_MINOR) "." TO_STRING(CUDNN_PATCHLEVEL)

/* ============== ConvolutionForwardImpl ============== */
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    AlgoBase::SizeArgs args{this, src, filter, dst};
    MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
    MEGDNN_MARK_USED_VAR(positive_attr);
    MEGDNN_MARK_USED_VAR(negative_attr);
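    // The heuristic always picks the default algorithm for plain forward
    // convolution, so the workspace limit and attribute filters are unused.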
    return &sm_algo_pack.algo_default;
}

std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
                                           const TensorLayout& filter,
                                           const TensorLayout& dst) {
    AlgoBase::SizeArgs args{this, src, filter, dst};
    return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args);
}

size_t ConvolutionForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    MEGDNN_MARK_USED_VAR(preprocessed_filter);
    AlgoBase::SizeArgs args{this, src, filter, dst};
    return megdnn::get_algorithm(this, src, filter, dst)
            ->get_workspace_in_bytes(args);
}

void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
                                  _megdnn_tensor_in filter,
                                  _megdnn_tensor_out dst,
                                  const PreprocessedFilter* preprocessed_filter,
                                  _megdnn_workspace workspace) {
    check_exec(src.layout, filter.layout, dst.layout, workspace.size,
               preprocessed_filter);
    AlgoBase::ExecArgs args(this, src, filter, dst, workspace);
    auto&& algo = get_algorithm(this, src.layout, filter.layout, dst.layout);
    algo->check_workspace(args, workspace).exec(args);
}

const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
    return "CUDA CONVOLUTION_FORWARD";
}

/* ============== ConvolutionBackwardDataImpl ============== */

void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
    auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout);
    algo->check_workspace(args, workspace).exec(args);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    return megdnn::get_all_algorithms<ConvolutionBackwardDataImpl>(
            {this, filter, diff, grad});
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fm = check_layout_fwd(grad, filter, diff);
    return get_algorithm_heuristic(filter, fm, diff, grad,
                                   workspace_limit_in_bytes, positive_attr,
                                   negative_attr);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
        const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);
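    // Selection order below: channel-wise special case, int8-only algorithms,
    // cuDNN heuristic, grouped wrapper around the cuDNN choice, and finally
    // the non-cuDNN / bfloat16 fallbacks.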

    if (args.filter_meta.group > 1 &&
        sm_algo_pack.chanwise.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        // prefer special chanwise impl
        return &sm_algo_pack.chanwise;
    }

    if (args.filter_layout->dtype.enumv() ==
        DTypeTrait<dtype::QuantizedS8>::enumv) {
        return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
                sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
                "cuda conv bwd_data", positive_attr, negative_attr);
    }

    auto get_cudnn_algo =
            [this, &args, workspace_limit_in_bytes, positive_attr,
             negative_attr]() -> ConvolutionBackwardDataImpl::AlgoBase* {
        auto cudnn_handle = cuda::cudnn_handle(this->handle());
        CUDNNBwdDataDescs desc;
        args.init_desc(desc);

#if CUDNN_MAJOR >= 7
        MEGDNN_MARK_USED_VAR(negative_attr);
        auto& cudnn = args.handle->cudnn();
        int max_count = 0;
        cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
                cudnn_handle, &max_count));
        SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
        int ret_count = 0;
        cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
                cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
                desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
                algo_perf.data()));
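        // Take the first candidate that fits the workspace limit; when
        // REPRODUCIBLE is requested, additionally require a deterministic
        // algorithm.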
        for (int i = 0; i < ret_count; ++i) {
            if (algo_perf[i].memory > workspace_limit_in_bytes)
                continue;
            if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
                if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
                    return reinterpret_cast<AlgoBase*>(
                            sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
                }
            } else {
                return reinterpret_cast<AlgoBase*>(
                        sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
            }
        }
        return nullptr;
#else
        cudnnConvolutionBwdDataAlgo_t algo;
        cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm(
                cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
                desc.conv_desc.desc, desc.grad_desc.desc,
                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
                workspace_limit_in_bytes, &algo));
        auto&& cast_algo =
                reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
        return reinterpret_cast<AlgoBase*>(
                megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
                        cast_algo, positive_attr, negative_attr));
#endif
    };

    if (is_cudnn_supported(args.as_fwd_args())) {
        if (auto algo = get_cudnn_algo())
            return algo;
    }

    if (args.filter_meta.group > 1) {
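        // Retry the cuDNN heuristic on the equivalent single-group problem;
        // if it succeeds, return the grouped wrapper of the chosen algorithm.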
        auto orig_args = args;
        TensorLayout a, b;
        AlgoGroupConvGeneral::modify_size_args(args, a, b);
        if (is_cudnn_supported(args.as_fwd_args())) {
            if (auto algo = get_cudnn_algo())
                return sm_algo_pack.algo2gconv.at(algo);
        }
        args = orig_args;
    }

    if (args.filter_layout->dtype.enumv() !=
        DTypeTrait<dtype::BFloat16>::enumv) {
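        // Everything except bfloat16 falls back to the non-cuDNN algorithms;
        // bfloat16 layouts get their own algorithm list below.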
        return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
                "cuda conv bwd_data", positive_attr, negative_attr);
    } else {
        return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>(
                sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
                "cuda conv bwd_data", positive_attr, negative_attr);
    }
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    AlgoBase::SizeArgs args(this, filter, diff, grad);
    return get_algorithm(this, filter, args.filter_meta, diff, grad)
            ->get_workspace_in_bytes(args);
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    return "CUDACONV0+CUDNN" CUDNN_VERSION_STR;
}

/* ============== ConvolutionBackwardFilterImpl ============== */

void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
                                         _megdnn_tensor_in diff,
                                         _megdnn_tensor_out grad,
                                         _megdnn_workspace workspace) {
    AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
    auto algo = get_algorithm(this, src.layout, diff.layout, grad.layout,
                              args.grad_filter_meta);
    algo->check_workspace(args, workspace).exec(args);
}

std::vector<ConvolutionBackwardFilterImpl::Algorithm*>
ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
                                                  const TensorLayout& diff,
                                                  const TensorLayout& grad) {
    return megdnn::get_all_algorithms<ConvolutionBackwardFilterImpl>(
            {this, src, diff, grad});
}

ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fm = check_layout_fwd(src, grad, diff);
    return get_algorithm_heuristic(src, diff, grad, fm,
                                   workspace_limit_in_bytes, positive_attr,
                                   negative_attr);
}

ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
        size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);
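    // Same selection strategy as the backward-data heuristic above: chanwise
    // special case, cuDNN heuristic, grouped wrapper, then the fallbacks.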

    if (args.grad_filter_meta.group > 1 &&
        sm_algo_pack.chanwise.is_available_attribute(
                args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
        // prefer special chanwise impl
        return &sm_algo_pack.chanwise;
    }

    auto get_cudnn_algo =
            [this, &args, workspace_limit_in_bytes, positive_attr,
             negative_attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
        auto cudnn_handle = cuda::cudnn_handle(this->handle());
        CUDNNBwdFilterDescs desc;
        args.init_desc(desc);

        // Disabled: this path caused a segfault in megbrain and needs further
        // investigation.
#if 0
        auto is_heuristic_success =
                convolution::PerformanceModelBackwardFilter::
                        get_algo_backward_filter_success(
                                args, desc, workspace_limit_in_bytes, &algo);
        if (is_heuristic_success) {
            return sm_algo_pack.cudnn_from_enum(algo);
        }
#endif
#if CUDNN_MAJOR >= 7
        MEGDNN_MARK_USED_VAR(negative_attr);
        auto& cudnn = args.handle->cudnn();
        int max_count = 0;
        cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
                cudnn_handle, &max_count));
        SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
        int ret_count = 0;
        cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
                cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
                desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
                algo_perf.data()));
        for (int i = 0; i < ret_count; ++i) {
            if (algo_perf[i].memory > workspace_limit_in_bytes)
                continue;
            if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) {
                if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
                    return reinterpret_cast<AlgoBase*>(
                            sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
                }
            } else {
                return reinterpret_cast<AlgoBase*>(
                        sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
            }
        }
        return nullptr;
#else
        cudnnConvolutionBwdFilterAlgo_t algo;
        cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm(
                cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
                desc.conv_desc.desc, desc.grad_desc.desc,
                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
                workspace_limit_in_bytes, &algo));
        auto&& cast_algo =
                reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
        return reinterpret_cast<AlgoBase*>(
                megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
                        cast_algo, positive_attr, negative_attr));
#endif
    };

    if (is_cudnn_supported(args.as_fwd_args())) {
        if (auto algo = get_cudnn_algo())
            return algo;
    }

    if (args.grad_filter_meta.group > 1) {
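        // As in the backward-data case: retry cuDNN on the single-group
        // problem and map the result to its grouped wrapper.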
        auto orig_args = args;
        TensorLayout a, b;
        AlgoGroupConvGeneral::modify_size_args(args, a, b);
        if (is_cudnn_supported(args.as_fwd_args())) {
            if (auto algo = get_cudnn_algo())
                return sm_algo_pack.algo2gconv.at(algo);
        }
        args = orig_args;
    }

    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
        return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
                "cuda conv bwd_filter", positive_attr, negative_attr);
    } else {
        return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>(
                sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
                "cuda conv bwd_filter", positive_attr, negative_attr);
    }
}

size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& diff,
        const TensorLayout& grad) {
    AlgoBase::SizeArgs args(this, src, diff, grad);
    return get_algorithm(this, src, diff, grad, args.grad_filter_meta)
            ->get_workspace_in_bytes(args);
}

const char* ConvolutionBackwardFilterImpl::get_algorithm_set_name() const {
    return "CUDACONV0+CUDNN" CUDNN_VERSION_STR;
}

// vim: syntax=cpp.doxygen