algo.cpp 4.3 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/convolution/backward_data/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

// Builds the registry of CUDA convolution backward-data algorithms.
//
// Lists populated here:
//   * non_cudnn_algos - chanwise, chanwise_small, matmul, and the
//                       AlgoGroupConvGeneral wrapper around matmul;
//   * int8_algos      - the int8 dot-product algorithms;
//   * bfloat16_algos  - the bfloat16 algorithm;
//   * all_algos       - chanwise kernels, cuDNN wrappers, matmul, int8
//                       algorithms, one group-conv wrapper per algorithm from
//                       index 2 onwards, and finally bfloat16.
// Every entry of all_algos is then indexed by its descriptor in
// m_all_algos_map.
ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&chanwise_small);
    non_cudnn_algos.push_back(&matmul);

    all_algos.push_back(&chanwise);        // prefer chanwise
    all_algos.push_back(&chanwise_small);  // prefer small chanwise

    fill_cudnn_algos();
    for (auto&& i : cudnn) {
        all_algos.push_back(&i);
    }
    all_algos.push_back(&matmul);

    fill_int8_dp4a_algos();
    for (auto&& algo : int8_nchw4_dotprod) {
        all_algos.push_back(&algo);
        int8_algos.push_back(&algo);
    }

    int8_algos.push_back(&int8_nchw_dotprod);
    all_algos.push_back(&int8_nchw_dotprod);

    // Reserve room for the group-conv wrappers appended below so all_algos
    // keeps its storage while growing (verified by the assert further down).
    all_algos.reserve(all_algos.size() * 2);

    // add gconv algos by AlgoGroupConvGeneral; the first two entries
    // (chanwise and chanwise_small) are not wrapped.
    auto all_algos_data = all_algos.data();
    size_t group_algo_start = 2;
    for (size_t i = group_algo_start; i < all_algos.size(); ++i) {
        gconv.push_back({all_algos[i]});
    }
    // Addresses are taken only after gconv has finished growing, so a
    // push_back cannot invalidate them (assuming gconv is a std::vector-like
    // container).
    for (size_t i = group_algo_start; i < all_algos.size(); ++i) {
        algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
    }
    for (auto&& i : gconv) {
        all_algos.push_back(&i);
    }
    megdnn_assert(all_algos_data == all_algos.data());

    // Group matmul. Look it up explicitly instead of taking the last entry of
    // all_algos: the int8 algorithms are registered after matmul, so the last
    // group-conv wrapper wraps int8_nchw_dotprod, not matmul.
    non_cudnn_algos.push_back(algo2gconv[&matmul]);

    all_algos.push_back(&bfloat16);
    bfloat16_algos.push_back(&bfloat16);

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

// Instantiates a MegDNN helper macro for ConvolutionBackwardDataImpl.
// NOTE(review): presumably this defines the lookup from an algorithm
// descriptor back to the registered algorithm (m_all_algos_map is keyed by
// algo->info().desc in AlgoPack::AlgoPack) — confirm against the macro
// definition.
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)

69 70 71
ConvolutionBackwardDataImpl::AlgoCUDNN*
ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum(
        cudnnConvolutionBwdDataAlgo_t algo) {
72
    for (auto&& i : cudnn) {
73 74 75
        if (i.cudnn_enum() == algo)
            return &i;
    }
76 77 78
    megdnn_throw(
            megdnn_mangle(ssprintf("can not find cudnn bwd_data algorithm %d",
                                   static_cast<int>(algo))));
79 80 81 82 83
}

ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;

// Convenience constructor: obtains the CanonizedFilterMeta from
// o->check_layout_fwd(grad, filter, diff) — i.e. the layouts are passed in
// grad, filter, diff order (presumably the forward src/filter/dst roles) —
// and delegates to the constructor that takes the metadata explicitly.
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
        ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
        const TensorLayout& diff, const TensorLayout& grad)
        : SizeArgs(o,
                   filter,
                   o->check_layout_fwd(grad, filter, diff),
                   diff,
                   grad) {}

// Full constructor. Stores:
//   * the concrete handle of the operator (via concrete_handle);
//   * a copy of the canonized filter metadata;
//   * the addresses of the three layouts (not copies), so the caller's
//     layouts must outlive this SizeArgs;
//   * the owning operator pointer.
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
        ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
        const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
        const TensorLayout& grad)
        : handle(concrete_handle(o->handle())),
          filter_meta(filter_meta),
          diff_layout(&diff),
          grad_layout(&grad),
          filter_layout(&filter),
          opr(o) {}

// Execution arguments. The size arguments are built from the tensors'
// layouts. The tensors are then recorded by address (they must outlive this
// ExecArgs), and the workspace argument is stored in the workspace member.
ConvolutionBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs(
        ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
        _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace)
        : SizeArgs(opr, filter.layout, diff.layout, grad.layout),
          filter_tensor(&filter),
          diff_tensor(&diff),
          grad_tensor(&grad),
          workspace(workspace) {}

std::string ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::to_string() const {
111
    auto&& fm = filter_meta;
112 113
    MEGDNN_MARK_USED_VAR(fm);
    return megdnn_mangle(ssprintf(
114 115 116 117 118 119 120
            "filter=%u{%u,%u,%u,%u}, diff=%s, grad=%s, "
            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s",
            fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1],
            diff_layout->to_string().c_str(), grad_layout->to_string().c_str(),
            fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1],
            fm.dilation[0], fm.dilation[1], !fm.should_flip,
            diff_layout->dtype.name(), grad_layout->dtype.name()));
121 122 123
}

// vim: syntax=cpp.doxygen