Commit 83cf4ee6 authored by Megvii Engine Team

refactor(dnn/rocm): remove some useless includes

GitOrigin-RevId: 3d2c315a368f7307a88ba37f0674a36072281578
Parent 323a4642
......@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
 ARGSORT_FOREACH_CTYPE(INST_FORWARD)
 INST_CUB_SORT(uint32_t)
-// INST_CUB_SORT(uint64_t)
+INST_CUB_SORT(uint64_t)
 #undef INST_CUB_SORT
 #undef INST_FORWARD
 }
......
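A note on the hunk above: INST_CUB_SORT appears to instantiate a radix-sort wrapper once per key type, and the change re-enables the uint64_t instantiation alongside uint32_t. The sketch below shows what such a wrapper can look like on ROCm, assuming the hipCUB mirror of CUB's DeviceRadixSort::SortPairs; the wrapper name, parameter list, and value type are illustrative, not the repository's.

    #include <cstddef>
    #include <cstdint>
    #include <hip/hip_runtime.h>
    #include <hipcub/hipcub.hpp>

    // Hypothetical stand-in for whatever INST_CUB_SORT instantiates.
    template <typename KeyT>
    hipError_t cub_sort_pairs(void* workspace, size_t& workspace_bytes,
                              const KeyT* keys_in, KeyT* keys_out,
                              const int* values_in, int* values_out, int n,
                              hipStream_t stream) {
        // A first call with workspace == nullptr only queries workspace_bytes.
        return hipcub::DeviceRadixSort::SortPairs(
                workspace, workspace_bytes, keys_in, keys_out, values_in,
                values_out, n, 0, int(sizeof(KeyT) * 8), stream);
    }

    // One explicit instantiation per key type, analogous to the two
    // INST_CUB_SORT lines in the hunk.
    template hipError_t cub_sort_pairs<uint32_t>(
            void*, size_t&, const uint32_t*, uint32_t*, const int*, int*, int,
            hipStream_t);
    template hipError_t cub_sort_pairs<uint64_t>(
            void*, size_t&, const uint64_t*, uint64_t*, const int*, int*, int,
            hipStream_t);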
......@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
         const int* iptr_src = NULL);
 //! iterate over all supported data types
+// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm)
 #define ARGSORT_FOREACH_CTYPE(cb) \
     cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16))
......
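The two argsort hunks revolve around an X-macro: ARGSORT_FOREACH_CTYPE(cb) applies the callback macro cb to every element type the ROCm argsort supports, and dt_float16 stays disabled because the device radix sort cannot key on half. A minimal, self-contained sketch of the instantiation pattern (simplified names and signature, not the repository's):

    #include <cstdint>

    // Stand-in for argsort::forward; the real signature has more parameters.
    template <typename T>
    void forward_stub(const T* src, T* dst, int* indices) {
        (void)src; (void)dst; (void)indices;
    }

    // X-macro: the supported element types are listed exactly once...
    #define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)

    // ...and each user supplies a callback that is stamped out per type.
    #define INST_FORWARD(T) template void forward_stub<T>(const T*, T*, int*);
    ARGSORT_FOREACH_CTYPE(INST_FORWARD)  // explicit instantiations for float, int32_t
    #undef INST_FORWARD
    #undef ARGSORT_FOREACH_CTYPE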
......@@ -14,8 +14,6 @@
#include "./argsort.h.hip"
#include "./backward.h.hip"
// #include "src/rocm/utils.h"
using namespace megdnn;
using namespace rocm;
using namespace argsort;
......
......@@ -11,13 +11,9 @@
#include "hcc_detail/hcc_defs_prologue.h"
#include "./bitonic_sort.h.hip"
// #include "src/cuda/query_blocksize.cuh"
// #include "megdnn/dtype.h"
#include "megdnn/dtype.h"
// #if __CUDACC_VER_MAJOR__ < 9
// #pragma message "warp sync disabled due to insufficient cuda version"
#define __syncwarp __syncthreads
// #endif
#include <algorithm>
#include <cmath>
......@@ -84,17 +80,17 @@ struct NumTrait<int32_t> {
     static __device__ __forceinline__ int32_t min() { return INT_MIN; }
 };
-// #if !MEGDNN_DISABLE_FLOAT16
-// template <>
-// struct NumTrait<dt_float16> {
-//     static __device__ __forceinline__ dt_float16 max() {
-//         return std::numeric_limits<dt_float16>::max();
-//     }
-//     static __device__ __forceinline__ dt_float16 min() {
-//         return std::numeric_limits<dt_float16>::lowest();
-//     }
-// };
-// #endif
+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+    static __device__ __forceinline__ dt_float16 max() {
+        return std::numeric_limits<dt_float16>::max();
+    }
+    static __device__ __forceinline__ dt_float16 min() {
+        return std::numeric_limits<dt_float16>::lowest();
+    }
+};
+#endif
 struct LessThan {
     template <typename Key, typename Value>
......@@ -310,7 +306,7 @@ namespace rocm {
 INST(float, int);
 INST(int32_t, int);
-// DNN_INC_FLOAT16(INST(dt_float16, int));
+DNN_INC_FLOAT16(INST(dt_float16, int));
 #undef INST
 } // namespace megdnn
......
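The bitonic_sort hunks re-enable the dt_float16 specialization of NumTrait, the small trait that hands the kernel a per-type sentinel maximum and minimum (useful, for instance, for padding partially filled sort tiles so the padding never outranks real elements). A host-only sketch of the idea, with float and int32_t specializations mirroring the ones visible above; the padding remark is an assumption about how the trait is used:

    #include <climits>
    #include <cstdint>
    #include <limits>

    template <typename T>
    struct NumTrait;  // primary template intentionally left undefined

    template <>
    struct NumTrait<float> {
        static float max() { return std::numeric_limits<float>::max(); }
        static float min() { return std::numeric_limits<float>::lowest(); }
    };

    template <>
    struct NumTrait<int32_t> {
        static int32_t max() { return INT_MAX; }
        static int32_t min() { return INT_MIN; }
    };

    // An ascending sort can pad with NumTrait<T>::max() and a descending one
    // with NumTrait<T>::min(), so padded slots sort behind every real element.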
......@@ -18,13 +18,7 @@
 #include <algorithm>
 #include <cmath>
-#if __CUDACC_VER_MAJOR__ < 9
-#pragma message "topk is a little slower on cuda earlier than 9.0"
-// on cuda 9.0 and later, due to thread-divergent branches we should use
-// __syncwarp; and I am too lazy to implement a correct legacy version, so just
-// use __syncthreads instead for older cuda
 #define __syncwarp __syncthreads
-#endif
 using namespace megdnn;
 using namespace rocm;
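Both the bitonic_sort and topk hunks keep the mapping #define __syncwarp __syncthreads while dropping the old CUDA-version guard around it: the ROCm toolchain targeted here has no __syncwarp, so warp-level barriers in the ported CUDA kernels are promoted to block-wide barriers, which is conservative but correct. A toy illustration of the effect (not repository code; assumes a launch with 64 threads per block):

    #include <hip/hip_runtime.h>

    #define __syncwarp __syncthreads  // same mapping as in the hunks above

    __global__ void block_reduce(const int* in, int* out) {
        __shared__ int buf[64];
        unsigned tid = threadIdx.x;
        buf[tid] = in[tid];
        for (unsigned stride = 32; stride > 0; stride >>= 1) {
            __syncwarp();  // under the define this is a full __syncthreads()
            if (tid < stride)
                buf[tid] += buf[tid + stride];
        }
        if (tid == 0)
            *out = buf[0];
    }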
......@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt,
 }
 }
-    //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
-    //        (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
-    //    // impossible
-    //    int* bad = 0x0;
-    //    *bad = 23;
-    //}
+    if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
+            (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
+        // impossible
+        int* bad = 0x0;
+        *bad = 23;
+    }
 }
 static uint32_t get_grid_dim_x(uint32_t length) {
......
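The last hunk restores a previously commented-out sanity check in update_prefix_and_k: if the prefix sums over the bucket counters are inconsistent, the kernel writes through a null pointer, deliberately faulting instead of letting top-k return silently wrong results. The same intent can be written as a tiny device-side trap helper; the helper below is a hypothetical sketch, not repository code, and assumes hip-clang's __builtin_trap is usable in device code:

    #include <hip/hip_runtime.h>

    // Hypothetical helper: abort the kernel when an "impossible" condition is
    // observed, so corruption is caught at the fault instead of propagating.
    __device__ __forceinline__ void trap_if(bool broken_invariant) {
        if (broken_invariant)
            __builtin_trap();
    }

    // Usage mirroring the restored check:
    //     trap_if(cumsum_bucket_cnt[NR_BUCKET] < kv ||
    //             cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum);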