/** * \file dnn/src/rocm/topk/topk_radix.cpp.hip * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "./topk_radix.h.hip" #include "src/rocm/utils.h.hip" #include #include "hipcub/device/device_scan.hpp" #include #include #if __CUDACC_VER_MAJOR__ < 9 #pragma message "topk is a little slower on cuda earlier than 9.0" // on cuda 9.0 and later, due to thread-divergent branches we should use // __syncwarp; and I am too lazy to implement a correct legacy version, so just // use __syncthreads instead for older cuda #define __syncwarp __syncthreads #endif using namespace megdnn; using namespace rocm; using namespace topk; using namespace internal; namespace rocm_topk_impl { const uint32_t WARP_SIZE = 32; static __device__ __forceinline__ uint32_t u32_from_64_low(uint64_t x) { return x; } static __device__ __forceinline__ uint32_t u32_from_64_high(uint64_t x) { return x >> 32; } template struct static_log2 { static const uint32_t val = static_log2::val + 1; }; template <> struct static_log2<1> { static const uint32_t val = 0; }; template struct DeviceScanPackedItem; template struct DeviceScanPackedItem<1, T> { __device__ __forceinline__ T load(T* data, uint32_t tid) { return data[tid]; } __device__ __forceinline__ void store(T* data, uint32_t tid, uint32_t s) { data[tid] = s; } }; template <> struct DeviceScanPackedItem<4, uint8_t> { uint8_t d0, d1, d2, d3; __device__ __forceinline__ uint32_t load(uint8_t* data, uint32_t tid) { uint32_t item = reinterpret_cast(data)[tid]; d3 = item >> 24; d2 = (item >> 16) & 0xFF; d1 = (item >> 8) & 0xFF; d0 = item & 0xFF; return d0 + d1 + d2 + d3; } __device__ __forceinline__ void store(uint8_t* data, uint32_t tid, uint32_t s) { uint8_t o3 = s, o2 = o3 - d3, o1 = o2 - d2, o0 = o1 - d1; reinterpret_cast(data)[tid] = (o3 << 24) | (o2 << 16) | (o1 << 8) | o0; } }; //! inclusive scan within a warp using register shuffle template __device__ __forceinline__ uint32_t device_scan_shfl_core(uint32_t s, uint32_t tid) { static const uint32_t SIZE_LOG2 = static_log2::val; uint32_t self_lane = tid % SIZE; #pragma unroll for (uint32_t step_log2 = 1; step_log2 <= SIZE_LOG2; ++step_log2) { uint32_t from_lane = (self_lane & ~((1u << step_log2) - 1)) + ((1 << (step_log2 - 1)) - 1); uint32_t valid_mask = (from_lane >= self_lane) - 1; uint32_t s_below = __shfl_up(s, self_lane - from_lane, SIZE); s += s_below & valid_mask; } return s; } /*! * \brief compute inplace inclusive prefix sum of \p data * * Note: no synchronization at the end */ template __device__ __forceinline__ void device_scan(uint32_t* data, uint32_t tid, uint32_t shard) { const uint32_t NR_WARP = SIZE / NR_SHARD / WARP_SIZE; #if __cplusplus > 199711L static_assert(NR_WARP <= WARP_SIZE || (NR_WARP & (NR_WARP - 1)), "bad params"); #endif __syncthreads(); DeviceScanPackedItem packed_item; uint32_t s = packed_item.load(data, tid); s = device_scan_shfl_core(s, tid); // sync between warps __shared__ uint32_t warp_sums_storage[NR_SHARD][NR_WARP]; uint32_t warp_id = tid / WARP_SIZE; uint32_t* warp_sums = warp_sums_storage[shard]; if ((tid & (WARP_SIZE - 1)) == WARP_SIZE - 1) { warp_sums[warp_id] = s; } __syncthreads(); for (uint32_t i = 0; i < warp_id; ++i) { s += warp_sums[i]; } packed_item.store(data, tid, s); } template __device__ __forceinline__ void device_scan_packed_accu32(T* data, uint32_t tid) { DeviceScanPackedItem scan_pack; __syncwarp(); uint32_t sum = scan_pack.load(data, tid); sum = device_scan_shfl_core(sum, tid); scan_pack.store(data, tid, sum); __syncwarp(); } namespace kth { const uint32_t BUCKET_BITS = 8, NR_BUCKET = 1 << BUCKET_BITS, LOCAL_CNT_SHARD = 16, BLOCK_DIM = NR_BUCKET; template struct enforce_const_u32 { static const uint32_t val = v; }; /*! * \brief compute scattered histogram for the whole input * * launch config: grid(X, batch), thread(BLOCK_DIM) * * Keys not starting with given prefix would be treated as max * * \param[in] input [batch, length] * \param[out] buckets [batch, X, NR_BUCKET] */ template static __global__ void compute_histogram(const ctype* input, uint32_t* bucket_cnt, uint32_t length, int32_t lda, uint32_t* prefix_ptr) { int32_t batch = blockIdx.y; input += batch * lda; bucket_cnt += (batch * gridDim.x + blockIdx.x) * NR_BUCKET; uint32_t prefix; if (prefix_valid) { prefix = prefix_ptr[batch]; } { // init bucket_cnt for (uint32_t i = threadIdx.x; i < NR_BUCKET; i += BLOCK_DIM) { bucket_cnt[i] = 0; } __syncthreads(); } { // accumulate uint32_t i = blockIdx.x * BLOCK_DIM + threadIdx.x, stride = BLOCK_DIM * gridDim.x; while (i < length) { uint32_t key = RadixConverter::to_radix(input[i]); if (prefix_valid) { const uint32_t mask = ((~0u) << ((prefix_valid ? shift : 0) + BUCKET_BITS)); key |= ((key & enforce_const_u32::val) == prefix) - 1; } uint32_t idx = (key >> shift) & ((1 << BUCKET_BITS) - 1); atomicAdd(bucket_cnt+idx, 1u); i += stride; } } __syncthreads(); } /*! * \brief update the values in \p prefix to k'th value in according to bucket * count, and update \p k * * launch config: grid(batch), thread(NR_BUCKET) */ template static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, uint32_t* prefix, uint32_t* k, uint32_t k_init, uint32_t bucket_sharding_size, ctype* result) { __shared__ uint32_t cumsum_bucket_cnt[NR_BUCKET + 1]; uint32_t batch = blockIdx.x; bucket_cnt += batch * bucket_sharding_size * NR_BUCKET; uint32_t sum = 0; for (uint32_t i = 0; i < bucket_sharding_size; ++i) { sum += bucket_cnt[i * NR_BUCKET + threadIdx.x]; } if (!threadIdx.x) { cumsum_bucket_cnt[0] = 0; } const uint32_t i = threadIdx.x + 1; cumsum_bucket_cnt[i] = sum; device_scan(cumsum_bucket_cnt + 1, threadIdx.x, 0); __syncthreads(); uint32_t kv = first ? k_init : k[batch]; if ((cumsum_bucket_cnt[i] >= kv) & (cumsum_bucket_cnt[i - 1] < kv)) { uint32_t b = (i - 1) << shift; if (first) { prefix[batch] = b; } else if (last) { result[batch] = RadixConverter::from_radix(prefix[batch] | b); } else { prefix[batch] |= b; } if (!last) { k[batch] = kv - cumsum_bucket_cnt[i - 1]; } } //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | // (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { // // impossible // int* bad = 0x0; // *bad = 23; //} } static uint32_t get_grid_dim_x(uint32_t length) { return std::max(length / (128 * BLOCK_DIM), 1); } } // namespace kth /*! * \brief select values smaller or larger than given threshold * * Note: we use register shuffle extensively to perform both reduce and scan. */ namespace select { struct LessPred { template __device__ __forceinline__ static bool cmp(ctype x, ctype y) { return x < y; } }; struct GreaterPred { template __device__ __forceinline__ static bool cmp(ctype x, ctype y) { return x > y; } }; const uint32_t REDUCE_WARP_SIZE = 16, REDUCE_SIZE = WARP_SIZE * 4, REDUCE_SHARD = 64; /*! * \brief reduce number of elements satisfying Pred in (N, M) mat to * (N, ceil(M / REDUCE_SIZE)) * * launch config: grid(X, batch), * thread(REDUCE_WARP_SIZE, REDUCE_SHARD) * * Each block computes REDUCE_SHARD outputs */ template static __global__ void kern_reduce_block_cnt(const ctype* input_data, const ctype* input_thresh, uint32_t length, int32_t lda, uint64_t* output, uint32_t output_width) { static const uint32_t BLOCK_DIM_X = REDUCE_WARP_SIZE, BLOCK_DIM_Y = REDUCE_SHARD; uint32_t batch = blockIdx.y, out_col = blockIdx.x * BLOCK_DIM_Y + threadIdx.y, col_begin = out_col * REDUCE_SIZE, col_end = min(col_begin + REDUCE_SIZE, length), tid_local = threadIdx.x; if (out_col >= output_width) { return; } uint32_t thresh = RadixConverter::to_radix(input_thresh[batch]); input_data += static_cast(batch) * lda; uint32_t sum_eq = 0, sum_lt = 0; for (uint32_t i = col_begin + tid_local; i < col_end; i += BLOCK_DIM_X) { uint32_t iv = RadixConverter::to_radix(input_data[i]); sum_eq += iv == thresh; sum_lt += Pred::cmp(iv, thresh); } #pragma unroll for (uint32_t step = REDUCE_WARP_SIZE / 2; step >= 1; step >>= 1) { sum_eq += __shfl_down(sum_eq, step, REDUCE_WARP_SIZE); sum_lt += __shfl_down(sum_lt, step, REDUCE_WARP_SIZE); } // reduce warp results to a single scalar if (!tid_local) { output[batch * output_width + out_col] = (static_cast(sum_eq) << 32) | sum_lt; } } static MEGDNN_NOINLINE hipError_t invoke_cub_scan(const uint64_t* input, uint64_t* output, void* workspace, size_t& workspace_size, uint32_t size, hipStream_t stream) { return hipcub::DeviceScan::InclusiveSum(workspace, workspace_size, input, output, size, stream); } static __global__ void kern_init_zero(uint64_t* dst) { dst[0] = 0; } /*! * \brief copy top-k values of each row from input to output * * launch config: grid(X, batch), * thread(WARP_SIZE, COPY_SHARD) */ template static __global__ void kern_copy(const ctype* input_data, const ctype* input_thresh, const uint64_t* scan, uint32_t scan_width, ctype* output_value, int32_t* output_idx, uint32_t length, uint32_t k, int32_t lda) { #if __cplusplus > 199711L static_assert(REDUCE_SIZE < 256, "local_sum_storage can not be uint8_t"); #endif static const uint32_t BLOCK_DIM_X = WARP_SIZE, BLOCK_DIM_Y = COPY_SHARD; uint32_t scan_col = blockIdx.x * BLOCK_DIM_Y + threadIdx.y; if (scan_col >= scan_width) { return; } uint32_t batch = blockIdx.y, inp_col_begin = min(scan_col * REDUCE_SIZE, length), inp_col_length = min(inp_col_begin + REDUCE_SIZE, length) - inp_col_begin, tid_local = threadIdx.x; uint32_t thresh = RadixConverter::to_radix(input_thresh[batch]); input_data += static_cast(batch) * lda + static_cast(inp_col_begin); __shared__ uint8_t local_sum_storage[BLOCK_DIM_Y][2][REDUCE_SIZE + 4]; uint8_t *local_sum_eq = local_sum_storage[threadIdx.y][0], *local_sum_lt = local_sum_storage[threadIdx.y][1]; if (!tid_local) { local_sum_eq[3] = 0; local_sum_lt[3] = 0; } local_sum_eq += 4; local_sum_lt += 4; const uint32_t WORKLOAD = REDUCE_SIZE / WARP_SIZE; #pragma unroll for (uint32_t j = 0; j < WORKLOAD; ++j) { uint32_t i = j * BLOCK_DIM_X + tid_local; if (i < inp_col_length) { uint32_t iv = RadixConverter::to_radix(input_data[i]); local_sum_eq[i] = iv == thresh; local_sum_lt[i] = Pred::cmp(iv, thresh); } else { local_sum_eq[i] = 0; local_sum_lt[i] = 0; } } device_scan_packed_accu32(local_sum_eq, tid_local); device_scan_packed_accu32(local_sum_lt, tid_local); scan += batch * scan_width; uint64_t scan_prev_pack = scan[static_cast(scan_col) - 1], k_offset_pack = scan_prev_pack - scan[-1], scan_self_pack = scan[scan_col] - scan_prev_pack; #define unpack(name) \ uint32_t name##_eq = u32_from_64_high(name##_pack), \ name##_lt = u32_from_64_low(name##_pack) unpack(k_offset); unpack(scan_self); #undef unpack uint32_t allowed_eq = k - min(k, (u32_from_64_low(scan[scan_width - 1]) - u32_from_64_low(scan[-1]))), ls_lt_max = k - min(k_offset_lt, k), ls_eq_max = allowed_eq - min(allowed_eq, k_offset_eq); if ((scan_self_lt && ls_lt_max) || (scan_self_eq && ls_eq_max)) { #pragma unroll for (uint32_t j = 0; j < WORKLOAD; ++j) { int32_t i = j * BLOCK_DIM_X + tid_local; uint32_t cur_lt = local_sum_lt[i], cur_eq = local_sum_eq[i]; bool is_lt = cur_lt <= ls_lt_max && cur_lt != local_sum_lt[i - 1]; bool is_eq = cur_eq <= ls_eq_max && cur_eq != local_sum_eq[i - 1]; // exactly one should be true if (is_lt || is_eq) { uint32_t off_lt = cur_lt + k_offset_lt - 1; uint32_t off_eq = cur_eq + k_offset_eq - 1 + (k - allowed_eq); uint32_t ocol = is_lt ? off_lt : off_eq; output_value[batch * k + ocol] = input_data[i]; output_idx[batch * k + ocol] = i + inp_col_begin; } } } } //! get workspace for scan, aligned to uint64_t static size_t get_scan_workspace(uint32_t size) { size_t wk = 0; hipError_t err = invoke_cub_scan(NULL, NULL, NULL, wk, size, NULL); if (err != hipSuccess) { fprintf(stderr, "topk: cub scan failed: %s (%d)\n", hipGetErrorString(err), static_cast(err)); megdnn_trap(); } return ((wk - 1) / sizeof(uint64_t) + 1) * sizeof(uint64_t); } } // namespace select } // namespace rocm_topk_impl uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length, uint32_t grid_dim_y_limit) { using namespace rocm_topk_impl::kth; uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * sizeof(uint32_t); } template hipError_t topk::find_kth_radix(const ctype* input, ctype* output, void* workspace, uint32_t batch, uint32_t length, int32_t lda, int32_t k, uint32_t grid_dim_y_limit, hipStream_t stream) { using namespace rocm_topk_impl::kth; if (!k) { return hipErrorInvalidValue; } if (k < 0) { k = length + k + 1; } if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) { // no c++11 in megdnn cuda; so we just trap instead of using static // assert megdnn_trap(); } uint32_t batch_idx = 0; uint32_t grid_dim_x = get_grid_dim_x(length); uint32_t grid_dim_y = 1; while (batch_idx < batch) { if (batch - batch_idx >= grid_dim_y_limit) { grid_dim_y = grid_dim_y_limit; } else { grid_dim_y = batch - batch_idx; } dim3 grid_dim(grid_dim_x, grid_dim_y); uint32_t* dev_k = static_cast(workspace); uint32_t* dev_prefix = dev_k + grid_dim_y; uint32_t* bucket_cnt = dev_prefix + grid_dim_y; compute_histogram<<>>( input + batch_idx * lda, bucket_cnt, length, lda, nullptr); // use float to make compiler happy; it is not used since last == false update_prefix_and_k <<>>( bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr); compute_histogram<<>>( input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix); update_prefix_and_k <<>>( bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr); compute_histogram<<>>( input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix); update_prefix_and_k <<>>( bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr); compute_histogram<<>>( input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix); update_prefix_and_k <<>>(bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, output + batch_idx); batch_idx += grid_dim_y; } return hipGetLastError(); } template hipError_t topk::topk_select(const ctype* input, const ctype* thresh, ctype* output_value, int32_t* output_idx, void* workspace, uint32_t batch, uint32_t length, int32_t lda, int32_t k, uint32_t batch_upper_limit, hipStream_t stream) { using namespace rocm_topk_impl; using namespace rocm_topk_impl::select; uint32_t length_split = DIVUP(length, REDUCE_SIZE); void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, uint64_t*, uint32_t); void (*kptr_copy)(const ctype*, const ctype*, const uint64_t*, uint32_t, ctype*, int32_t*, uint32_t, uint32_t, int32_t); int kern_copy_shard; { int grid, block; hipError_t err = hipOccupancyMaxPotentialBlockSize( &grid, &block, kern_copy); if (err) { return err; } kern_copy_shard = block / (WARP_SIZE * 8) * 8; if (!kern_copy_shard) { fprintf(stderr, "topk: failed to launch: block=%d\n", block); return hipErrorLaunchOutOfResources; } } #define CASE_SHARD_ON(pred, n) \ case n: \ kptr_copy = kern_copy; \ break #define CASE_SHARD(pred) \ switch (kern_copy_shard) { \ CASE_SHARD_ON(pred, 8); \ CASE_SHARD_ON(pred, 16); \ CASE_SHARD_ON(pred, 24); \ CASE_SHARD_ON(pred, 32); \ default: \ fprintf(stderr, "topk: failed to launch: shard=%d\n", \ kern_copy_shard); \ return hipErrorLaunchOutOfResources; \ } if (k < 0) { k = -k; kptr_reduce_block_cnt = kern_reduce_block_cnt; CASE_SHARD(GreaterPred); } else { kptr_reduce_block_cnt = kern_reduce_block_cnt; CASE_SHARD(LessPred); } #undef CASE_SHARD #undef CASE_SHARD_ON uint32_t batch_idx = 0; uint32_t batch_real = 1; while (batch_idx < batch) { if (batch - batch_idx >= batch_upper_limit) { batch_real = batch_upper_limit; } else { batch_real = batch - batch_idx; } size_t scan_size = batch_real * length_split; size_t scan_wk = get_scan_workspace(scan_size); uint64_t *scan_inp = static_cast(workspace) + scan_wk / sizeof(uint64_t), *scan_out = scan_inp + scan_size; // reduce to scan_inp kptr_reduce_block_cnt<<< dim3(DIVUP(length_split, REDUCE_SHARD), batch_real), dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>( input + batch_idx * lda, thresh + batch_idx, length, lda, scan_inp, length_split); // scan to scan_out scan_out += 1; // set scan[-1] to 0 hipError_t err = invoke_cub_scan(scan_inp, scan_out, workspace, scan_wk, scan_size, stream); if (err != hipSuccess) { return err; } kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1); // copy result kptr_copy<<>>( input + batch_idx * lda, thresh + batch_idx, scan_out, length_split, output_value + std::abs(k) * batch_idx, output_idx + std::abs(k) * batch_idx, length, k, lda); batch_idx += batch_real; } return hipGetLastError(); } uint32_t topk::topk_select_workspace(uint32_t batch, uint32_t length) { using namespace rocm_topk_impl::select; size_t scan_size = batch * DIVUP(length, REDUCE_SIZE); return get_scan_workspace(scan_size) + sizeof(uint64_t) * (scan_size * 2 + 1); } namespace megdnn { namespace rocm { namespace topk { #define INST(t) \ template hipError_t find_kth_radix(const t*, t*, void*, uint32_t, \ uint32_t, int32_t, int32_t, \ uint32_t, hipStream_t); \ template hipError_t topk_select(const t*, const t*, t*, int32_t*, \ void*, uint32_t, uint32_t, int32_t, \ int32_t, uint32_t, hipStream_t) INST(float); INST(int32_t); // DNN_INC_FLOAT16(INST(dt_float16)); #undef INST } // namespace topk } // namespace rocm } // namespace megdnn // vim: ft=rocm syntax=rocm.doxygen