/**
 * \file dnn/src/cuda/convolution3d/forward/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

/**
 * Builds the registry of 3D-convolution forward algorithms.
 *
 * - non_cudnn_algos: algorithms that do not go through cuDNN, plus the
 *   group-conv wrappers of inplace_matmul and 1x1x1.
 * - all_algos: chanwise, every cuDNN algorithm (filled by fill_cudnn_algos()),
 *   inplace_matmul, 1x1x1, followed by one AlgoGroupConvGeneral wrapper for
 *   every entry except all_algos[0] (chanwise).
 * - algo2gconv: maps each wrapped algorithm to its wrapper in gconv.
 * - m_all_algos_map: maps algo->info().desc to the algorithm.
 */
Convolution3DForwardImpl::AlgoPack::AlgoPack() {
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&inplace_matmul);
    non_cudnn_algos.push_back(&a1x1x1);

    all_algos.push_back(&chanwise);
    fill_cudnn_algos();
    for (auto&& i : cudnn) {
        all_algos.push_back(&i);
    }
    all_algos.push_back(&inplace_matmul);
    all_algos.push_back(&a1x1x1);

    // Reserve room for the group-conv wrappers appended below so all_algos is
    // not reallocated while it grows (checked by the assert further down).
    all_algos.reserve(all_algos.size() * 2);

    // add gconv algos by AlgoGroupConvGeneral, skipping all_algos[0] (chanwise).
    // NOTE(review): chanwise is presumably skipped because it already handles
    // grouped filters itself — TODO confirm.
    auto all_algos_data = all_algos.data();
    for (size_t i = 1; i < all_algos.size(); ++i) {
        gconv.push_back({all_algos[i]});
    }
    // Element addresses are taken only after gconv has stopped growing, so no
    // later push_back can invalidate them.
    for (size_t i = 1; i < all_algos.size(); ++i) {
        algo2gconv[all_algos[i]] = &gconv[i - 1];
    }
    for (auto&& i : gconv) {
        all_algos.push_back(&i);
    }
    megdnn_assert(all_algos_data == all_algos.data());

    // Wrappers of non-cuDNN algorithms are non-cuDNN as well. Look them up by
    // the wrapped algorithm rather than by position at the tail of all_algos,
    // so reordering the registrations above cannot pick the wrong wrapper.
    non_cudnn_algos.push_back(algo2gconv[&inplace_matmul]);  // group inplace_matmul
    non_cudnn_algos.push_back(algo2gconv[&a1x1x1]);          // group 1x1x1

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}

53 54
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl)

55 56 57 58 59 60 61
/**
 * Finds the registered cuDNN algorithm wrapper whose cudnn_enum() equals
 * \p algo.
 *
 * \param algo cuDNN forward algorithm enumerator to look up.
 * \return pointer to the matching element of `cudnn`.
 * Reports an error through megdnn_throw when no registered wrapper matches.
 */
Convolution3DForwardImpl::AlgoCUDNN*
Convolution3DForwardImpl::AlgoPack::cudnn_from_enum(
        cudnnConvolutionFwdAlgo_t algo) {
    for (auto&& i : cudnn) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
    megdnn_throw(ssprintf("can not find cudnn fwd algorithm %d",
                          static_cast<int>(algo)));
}

// Definition of the static algorithm registry shared by Convolution3DForwardImpl;
// populated by AlgoPack::AlgoPack().
Convolution3DForwardImpl::AlgoPack Convolution3DForwardImpl::sm_algo_pack{};

/**
 * Convenience constructor taking a raw filter layout: canonizes it with
 * o->make_canonized_filter_meta(src.ndim, filter) and delegates to the
 * CanonizedFilterMeta overload.
 */
Convolution3DForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        Convolution3DForwardImpl* o, const TensorLayout& src,
        const TensorLayout& filter, const TensorLayout& dst)
        : SizeArgs(o, src, o->make_canonized_filter_meta(src.ndim, filter),
                   dst) {}

Convolution3DForwardImpl::AlgoBase::SizeArgs::SizeArgs(
        Convolution3DForwardImpl *o,
        const TensorLayout &src, const CanonizedFilterMeta &filter,
        const TensorLayout &dst):
    ForwardSizeArgs{
        concrete_handle(o->handle()),
        &src, filter, &dst,
        o->param().data_type
    },
    opr{o}
{
}

/**
 * Execution arguments: the SizeArgs derived from the tensors' layouts plus
 * the addresses of the src/filter/dst tensors and a copy of the workspace.
 * NOTE(review): storing &src etc. assumes _megdnn_tensor_in/out are reference
 * types whose referents outlive this object — confirm against their macros.
 */
Convolution3DForwardImpl::AlgoBase::ExecArgs::ExecArgs(
        Convolution3DForwardImpl* opr, _megdnn_tensor_in src,
        _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        _megdnn_workspace workspace)
        : SizeArgs(opr, src.layout, filter.layout, dst.layout),
          src_tensor(&src),
          filter_tensor(&filter),
          dst_tensor(&dst),
          workspace(workspace) {}

std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const {
    auto &&fm = filter_meta;
    MEGDNN_MARK_USED_VAR(fm);
M
Megvii Engine Team 已提交
104 105 106 107 108 109 110 111 112 113
    return ssprintf(
            "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, "
            "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, "
            "dtype=%s,%s",
            src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg,
            fm.spatial[0], fm.spatial[1], fm.spatial[2],
            dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1],
            fm.padding[2], fm.stride[0], fm.stride[1], fm.stride[2],
            fm.dilation[0], fm.dilation[1], fm.dilation[2], !fm.should_flip,
            src_layout->dtype.name(), dst_layout->dtype.name());
114 115 116
}

// vim: syntax=cpp.doxygen