opr_impl.cpp 6.0 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/cuda/topk/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./opr_impl.h"
#include "./topk_radix.cuh"
#include "src/common/utils.h"
#include "src/cuda/argsort/argsort.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

template <typename ctype>
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
                                   const ctype* data, ctype* values,
                                   int* indices, void* workspace) {
25 26 27
    auto _handle = concrete_handle(handle());
    auto stream = _handle->stream();
    size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
28 29 30
    switch (param().mode) {
        case Param::Mode::KTH_ONLY:
            cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
31 32
                                                   n, lda, k, grid_dim_y_limit,
                                                   stream));
33 34 35 36 37 38
            return;
        case Param::Mode::VALUE_IDX_NOSORT: {
            WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
            auto thresh = static_cast<ctype*>(wk_bundle.get(0));
            auto real_wk = wk_bundle.get(1);
            cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
39 40
                                                   lda, k, grid_dim_y_limit,
                                                   stream));
41
            cuda_check(topk::topk_select<ctype>(data, thresh, values, indices,
42 43
                                                real_wk, m, n, lda, k,
                                                grid_dim_y_limit, stream));
44 45 46 47 48 49 50 51 52 53 54 55
            return;
        }
        case Param::Mode::VALUE_IDX_SORTED: {
            WorkspaceBundle wk_bundle{
                    workspace,
                    {m * sizeof(ctype), m * std::abs(k) * sizeof(ctype),
                     m * std::abs(k) * sizeof(int32_t), 1}};
            auto thresh = static_cast<ctype*>(wk_bundle.get(0)),
                 nosort_values = static_cast<ctype*>(wk_bundle.get(1));
            auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
            auto real_wk = wk_bundle.get(3);
            cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
56 57
                                                   lda, k, grid_dim_y_limit,
                                                   stream));
58 59
            cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values,
                                                nosort_idx, real_wk, m, n, lda,
60
                                                k, grid_dim_y_limit, stream));
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
            argsort::forward(nosort_values, values, indices, real_wk, m,
                             std::abs(k), k > 0, stream, nosort_idx);
            return;
        }
    }
    megdnn_throw("bad topk mode");
}

void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
                       int32_t* indices, _megdnn_workspace workspace) {
    switch (data.layout.dtype.enumv()) {
        case DTypeEnum::Float32:
            dispatch_with_ctype<float>(k, data.layout[0], data.layout[1],
                                       data.layout.stride[0], data.ptr<float>(),
                                       values.ptr<float>(), indices,
                                       workspace.raw_ptr);
            return;
        case DTypeEnum::Int32:
            dispatch_with_ctype<int32_t>(k, data.layout[0], data.layout[1],
                                       data.layout.stride[0], data.ptr<int32_t>(),
                                       values.ptr<int32_t>(), indices,
                                       workspace.raw_ptr);
            return;
84 85 86 87 88 89 90 91
#if !MEGDNN_DISABLE_FLOAT16
        case DTypeEnum::Float16:
            dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
                                       data.layout.stride[0], data.ptr<dt_float16>(),
                                       values.ptr<dt_float16>(), indices,
                                       workspace.raw_ptr);
            return;
#endif
92 93
        default:
            megdnn_throw(
94 95
                    ssprintf("only float32, int32 and float16 supported for "
                             "cuda topk, got: %s",
96 97 98 99 100 101 102 103 104 105 106
                             data.layout.dtype.name()));
    }
}

size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
                                        const TensorLayout& values,
                                        const TensorLayout& indices) {
    MEGDNN_MARK_USED_VAR(values);
    MEGDNN_MARK_USED_VAR(indices);
    size_t m = data[0], n = data[1];
    size_t kabs = std::abs(k);
107 108
    size_t grid_dim_y_limit =
            concrete_handle(handle())->device_prop().maxGridSize[1];
109 110
    megdnn_assert(std::max(m, n) <=
                  static_cast<size_t>(std::numeric_limits<int>::max()));
111
    size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
           sel = topk::topk_select_workspace(m, n);
    auto ctsize = data.dtype.size();
    switch (param().mode) {
        case Param::Mode::KTH_ONLY:
            return kth;
        case Param::Mode::VALUE_IDX_NOSORT:
            return WorkspaceBundle{nullptr, {m * ctsize, std::max(kth, sel)}}
                    .total_size_in_bytes();
        case Param::Mode::VALUE_IDX_SORTED:
            return WorkspaceBundle{
                    nullptr,
                    {m * ctsize, m * kabs * ctsize, m * kabs * sizeof(int32_t),
                     std::max(std::max(kth, sel),
                              argsort::get_fwd_workspace_in_bytes(
                                      m, kabs, data.dtype, k > 0, true))}}
                    .total_size_in_bytes();
    }
    megdnn_throw("bad topk mode");
}

// vim: syntax=cpp.doxygen