提交 54e79dd1 编写于 作者: M Megvii Engine Team

perf(mgb/cuda): do not call cudaGetDeviceProperties to avoid io traffic

GitOrigin-RevId: 6aa35928c8ec737d244fdb3ca9639ae49b03b284
上级 5f171298
......@@ -22,20 +22,25 @@ template <typename ctype>
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
const ctype* data, ctype* values,
int* indices, void* workspace) {
auto stream = concrete_handle(handle())->stream();
auto _handle = concrete_handle(handle());
auto stream = _handle->stream();
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
switch (param().mode) {
case Param::Mode::KTH_ONLY:
cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
n, lda, k, stream));
n, lda, k, grid_dim_y_limit,
stream));
return;
case Param::Mode::VALUE_IDX_NOSORT: {
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
auto thresh = static_cast<ctype*>(wk_bundle.get(0));
auto real_wk = wk_bundle.get(1);
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, stream));
lda, k, grid_dim_y_limit,
stream));
cuda_check(topk::topk_select<ctype>(data, thresh, values, indices,
real_wk, m, n, lda, k, stream));
real_wk, m, n, lda, k,
grid_dim_y_limit, stream));
return;
}
case Param::Mode::VALUE_IDX_SORTED: {
......@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
auto real_wk = wk_bundle.get(3);
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, stream));
lda, k, grid_dim_y_limit,
stream));
cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values,
nosort_idx, real_wk, m, n, lda,
k, stream));
k, grid_dim_y_limit, stream));
argsort::forward(nosort_values, values, indices, real_wk, m,
std::abs(k), k > 0, stream, nosort_idx);
return;
......@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
MEGDNN_MARK_USED_VAR(indices);
size_t m = data[0], n = data[1];
size_t kabs = std::abs(k);
size_t grid_dim_y_limit =
concrete_handle(handle())->device_prop().maxGridSize[1];
megdnn_assert(std::max(m, n) <=
static_cast<size_t>(std::numeric_limits<int>::max()));
size_t kth = topk::find_kth_radix_workspace(m, n),
size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
sel = topk::topk_select_workspace(m, n);
auto ctsize = data.dtype.size();
switch (param().mode) {
......
......@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) {
} // namespace select
} // namespace cuda_topk_impl
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) {
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length,
uint32_t grid_dim_y_limit) {
using namespace cuda_topk_impl::kth;
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
sizeof(uint32_t);
......@@ -488,6 +480,7 @@ template <typename ctype>
cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
void* workspace, uint32_t batch,
uint32_t length, int32_t lda, int32_t k,
uint32_t grid_dim_y_limit,
cudaStream_t stream) {
using namespace cuda_topk_impl::kth;
if (!k) {
......@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
megdnn_trap();
}
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t batch_idx = 0;
uint32_t grid_dim_x = get_grid_dim_x(length);
uint32_t grid_dim_y = 1;
......@@ -567,20 +550,11 @@ template <typename ctype>
cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
ctype* output_value, int32_t* output_idx,
void* workspace, uint32_t batch, uint32_t length,
int32_t lda, int32_t k, cudaStream_t stream) {
int32_t lda, int32_t k,
uint32_t batch_upper_limit, cudaStream_t stream) {
using namespace cuda_topk_impl;
using namespace cuda_topk_impl::select;
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t batch_upper_limit = prop.maxGridSize[1];
uint32_t length_split = DIVUP(length, REDUCE_SIZE);
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t,
......@@ -688,10 +662,10 @@ namespace topk {
#define INST(t) \
template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \
uint32_t, int32_t, int32_t, \
cudaStream_t); \
uint32_t, cudaStream_t); \
template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \
void*, uint32_t, uint32_t, int32_t, \
int32_t, cudaStream_t)
int32_t, uint32_t, cudaStream_t)
INST(float);
INST(int32_t);
#undef INST
......
......@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> {
template <typename ctype>
cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace,
uint32_t batch, uint32_t length, int32_t lda,
int32_t k, cudaStream_t stream);
int32_t k, uint32_t grid_dim_y_limit,
cudaStream_t stream);
//! get workspace in bytes
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length);
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length,
uint32_t grid_dim_y_limit);
/*!
* \brief select values from rows of input that compare to thresh as specified
......@@ -90,7 +92,8 @@ template <typename ctype>
cudaError_t topk_select(const ctype* input, const ctype* thresh,
ctype* output_value, int32_t* output_idx,
void* workspace, uint32_t batch, uint32_t length,
int32_t lda, int32_t k, cudaStream_t stream);
int32_t lda, int32_t k, uint32_t batch_upper_limit,
cudaStream_t stream);
uint32_t topk_select_workspace(uint32_t batch, uint32_t length);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册