argsort.cpp.hip 6.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
/**
 * \file dnn/src/rocm/argsort/argsort.cpp.hip
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"

#include "src/rocm/utils.h.hip"
#include "./argsort.h.hip"
#include "./bitonic_sort.h.hip"
#include "megdnn/basic_types.h"

#include "hipcub/device/device_radix_sort.hpp"
#include "hipcub/device/device_segmented_radix_sort.hpp"

using namespace megdnn;
using namespace rocm;

namespace {
/*!
 * \brief index-to-offset mapping: entry \c i is <tt>stride * i + bias</tt>
 *
 * Passed to hipcub::DeviceSegmentedRadixSort as the begin/end segment
 * offsets in this file; only operator[] is provided.
 */
struct StridedOffsetIterator {
    int bias;
    int stride;

    StridedOffsetIterator(int bias_, int stride_)
            : bias(bias_), stride(stride_) {}

    //! offset associated with the i-th entry
    __device__ __forceinline__ int operator[](int i) const {
        const int scaled = stride * i;
        return scaled + bias;
    }
};

//! whether bitonic sort should handle rows of length \p N; the batch size is
//! not taken into account
bool use_bitonic(uint32_t /*M*/, uint32_t N) {
    // bitonic sort is preferred when N is small (always faster than radix
    // sort), i.e. whenever the row does not exceed its length limit
    const bool fits_bitonic = !(N > BITONIC_SORT_MAX_LENGTH);
    return fits_bitonic;
}

//! whether one segmented radix sort over the whole batch should be used
//! instead of one radix sort per row; decided by the batch size \p M only
bool use_segmented(uint32_t M, uint32_t /*N*/) {
    // The threshold is an empirical value, based on these measurements:
    //   sort(1, 1e6):              0.574ms
    //   segsort({1,2,8,16}, 1e6):  7-8ms
    //   sort(1, 1e7):              3.425ms
    //   segsort({1,2,8,16}, 1e7):  71-84ms
    //
    // segsort is about 7x-10x slower than sort on small batches, so it is
    // only expected to beat per-row sort once the batch is large enough.
    const uint32_t min_segmented_batch = 8;
    return !(M < min_segmented_batch);
}

/*!
 * \brief write <tt>i % mod</tt> to <tt>dst[i]</tt> for every i in [0, n)
 *
 * One thread handles one element, indexed from the x dimension of the
 * launch only; threads whose index is not below \p n do nothing.
 */
__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) {
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) {
        return;
    }
    dst[idx] = idx % mod;
}

/*!
 * \brief workspace size in bytes needed to sort M rows of N keys of type
 *      \p ctype (paired with int values)
 *
 * Returns 0 when use_bitonic() selects the bitonic path. Otherwise the size
 * is obtained from argsort::cub_sort_pairs called with null buffers and a
 * zero workspace size.
 */
template <typename ctype>
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) {
    if (!use_bitonic(M, N)) {
        // same bit range as the one argsort::forward sorts with
        const int key_begin_bit = 0;
        const int key_end_bit = sizeof(float) * 8;
        return argsort::cub_sort_pairs<ctype, int>(
                is_ascending, nullptr, 0, nullptr, nullptr, nullptr, nullptr,
                M, N, key_begin_bit, key_end_bit, nullptr);
    }
    return 0;
}
}  // anonymous namespace

/*!
 * \brief radix-sort M rows of N (key, value) pairs by key through hipcub
 *
 * When use_segmented(M, N) holds, a single DeviceSegmentedRadixSort call
 * covers all rows, with row r spanning [r * N, (r + 1) * N). Otherwise
 * DeviceRadixSort is invoked once per row on that row's slice of every
 * buffer; in this path a null \p keys_in ends the loop after the first call.
 *
 * \param is_ascending selects SortPairs (true) or SortPairsDescending (false)
 * \param workspace temporary storage forwarded to hipcub
 * \param workspace_size size of \p workspace in bytes; handed to hipcub as
 *      its temp-storage-bytes argument
 * \param begin_bit,end_bit key bit range forwarded to hipcub
 * \return \p workspace_size as left by the hipcub call(s). NOTE(review): per
 *      the hipcub convention a null \p workspace should make the call only
 *      report the required size, which is how get_sort_workspace() relies on
 *      this function.
 */
template <typename KeyType, typename ValueType>
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(
        bool is_ascending, void* workspace, size_t workspace_size,
        const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in,
        ValueType* values_out, uint32_t M, uint32_t N, int begin_bit,
        int end_bit, hipStream_t stream) {
    hipError_t err;

    if (use_segmented(M, N)) {
        // segment r starts at r * N and ends at N + r * N
        StridedOffsetIterator seg_begin(0, N);
        StridedOffsetIterator seg_end(N, N);
        err = is_ascending
                      ? hipcub::DeviceSegmentedRadixSort::SortPairs(
                                workspace, workspace_size, keys_in, keys_out,
                                values_in, values_out, N * M, M, seg_begin,
                                seg_end, begin_bit, end_bit, stream)
                      : hipcub::DeviceSegmentedRadixSort::SortPairsDescending(
                                workspace, workspace_size, keys_in, keys_out,
                                values_in, values_out, N * M, M, seg_begin,
                                seg_end, begin_bit, end_bit, stream);
        hip_check(err);
        return workspace_size;
    }

    // small batch: one plain radix sort per row
    const bool no_input = !keys_in;
    for (size_t row = 0; row < M; ++row) {
        const size_t shift = N * row;
        err = is_ascending
                      ? hipcub::DeviceRadixSort::SortPairs(
                                workspace, workspace_size, keys_in + shift,
                                keys_out + shift, values_in + shift,
                                values_out + shift, N, begin_bit, end_bit,
                                stream)
                      : hipcub::DeviceRadixSort::SortPairsDescending(
                                workspace, workspace_size, keys_in + shift,
                                keys_out + shift, values_in + shift,
                                values_out + shift, N, begin_bit, end_bit,
                                stream);
        hip_check(err);
        if (no_input) {
            // without input keys a single call is enough
            break;
        }
    }
    return workspace_size;
}

/*!
 * \brief total workspace size in bytes required by argsort::forward
 *
 * The result is the sort workspace reported by get_sort_workspace() for the
 * given dtype. When the caller does not supply source indices
 * (\p iptr_src_given is false), the sort workspace is rounded up to a
 * multiple of sizeof(float) and followed by room for M * N ints, matching
 * the layout argsort::forward uses for its generated index buffer.
 *
 * \param M number of rows (batch size)
 * \param N number of elements per row
 * \param dtype key data type; must be one of the types enumerated by
 *      ARGSORT_FOREACH_CTYPE, otherwise megdnn_throw is invoked
 * \param is_ascending sort direction, forwarded to get_sort_workspace()
 * \param iptr_src_given whether the caller provides the source index buffer
 */
size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
                                           bool is_ascending,
                                           bool iptr_src_given) {
    size_t size = 0;
    switch (dtype.enumv().ev) {
#define cb(ctype)                                             \
    case DTypeTrait<ctype>::enumv:                            \
        size = get_sort_workspace<ctype>(M, N, is_ascending); \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("argsort only supports float, int32 and float16");
    }
    if (!iptr_src_given) {
        // widen before multiplying: M * N evaluated in uint32_t would wrap
        // around for M * N >= 2^32 and under-report the index buffer size
        const size_t nr_elems = static_cast<size_t>(M) * N;
        size = DIVUP(size, sizeof(float)) * sizeof(float) +
               nr_elems * sizeof(int);
    }
    return size;
}

template <typename dtype>
void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr,
                      void* workspace, uint32_t M, uint32_t N,
                      bool is_ascending, hipStream_t stream,
                      const int* iptr_src) {
    size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending);
    if (!iptr_src) {
        int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) +
                                          DIVUP(wk_size, sizeof(float)) *
                                                  sizeof(float));
        kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N);
        iptr_src = ptr;
    }

    if (use_bitonic(M, N)) {
        hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending,
                                stream));
    } else {
        cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src,
                       iptr, M, N, 0, sizeof(float)*8, stream);
    }
}

// Explicit template instantiations of the argsort entry points defined in
// this file.
namespace megdnn {
namespace rocm {

// Instantiates argsort::cub_sort_pairs with `dtype` used as both the key
// type and the value type.
#define INST_CUB_SORT(dtype)                                                 \
template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool,  \
                                    void*, size_t, const dtype*, dtype*,     \
                                    const dtype*, dtype*, uint32_t, uint32_t,\
                                    int, int, hipStream_t);

// Instantiates argsort::forward for keys of type `dtype`.
#define INST_FORWARD(dtype)                                                  \
template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*,     \
                                    uint32_t, uint32_t, bool, hipStream_t,   \
                                    const int*);
                                    
// forward is instantiated for every type listed by ARGSORT_FOREACH_CTYPE;
// cub_sort_pairs is instantiated explicitly for uint32_t keys/values only
// (the uint64_t variant is disabled).
ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT
#undef INST_FORWARD
}  // namespace rocm
}  // namespace megdnn
// vim: ft=rocm syntax=rocm.doxygen