提交 54e79dd1 编写于 作者: M Megvii Engine Team

perf(mgb/cuda): do not call cudaGetDeviceProperties to avoid io traffic

GitOrigin-RevId: 6aa35928c8ec737d244fdb3ca9639ae49b03b284
上级 5f171298
...@@ -22,20 +22,25 @@ template <typename ctype> ...@@ -22,20 +22,25 @@ template <typename ctype>
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
const ctype* data, ctype* values, const ctype* data, ctype* values,
int* indices, void* workspace) { int* indices, void* workspace) {
auto stream = concrete_handle(handle())->stream(); auto _handle = concrete_handle(handle());
auto stream = _handle->stream();
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
switch (param().mode) { switch (param().mode) {
case Param::Mode::KTH_ONLY: case Param::Mode::KTH_ONLY:
cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m, cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
n, lda, k, stream)); n, lda, k, grid_dim_y_limit,
stream));
return; return;
case Param::Mode::VALUE_IDX_NOSORT: { case Param::Mode::VALUE_IDX_NOSORT: {
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}}; WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
auto thresh = static_cast<ctype*>(wk_bundle.get(0)); auto thresh = static_cast<ctype*>(wk_bundle.get(0));
auto real_wk = wk_bundle.get(1); auto real_wk = wk_bundle.get(1);
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, stream)); lda, k, grid_dim_y_limit,
stream));
cuda_check(topk::topk_select<ctype>(data, thresh, values, indices, cuda_check(topk::topk_select<ctype>(data, thresh, values, indices,
real_wk, m, n, lda, k, stream)); real_wk, m, n, lda, k,
grid_dim_y_limit, stream));
return; return;
} }
case Param::Mode::VALUE_IDX_SORTED: { case Param::Mode::VALUE_IDX_SORTED: {
...@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, ...@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2)); auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
auto real_wk = wk_bundle.get(3); auto real_wk = wk_bundle.get(3);
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
lda, k, stream)); lda, k, grid_dim_y_limit,
stream));
cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values, cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values,
nosort_idx, real_wk, m, n, lda, nosort_idx, real_wk, m, n, lda,
k, stream)); k, grid_dim_y_limit, stream));
argsort::forward(nosort_values, values, indices, real_wk, m, argsort::forward(nosort_values, values, indices, real_wk, m,
std::abs(k), k > 0, stream, nosort_idx); std::abs(k), k > 0, stream, nosort_idx);
return; return;
...@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data, ...@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
MEGDNN_MARK_USED_VAR(indices); MEGDNN_MARK_USED_VAR(indices);
size_t m = data[0], n = data[1]; size_t m = data[0], n = data[1];
size_t kabs = std::abs(k); size_t kabs = std::abs(k);
size_t grid_dim_y_limit =
concrete_handle(handle())->device_prop().maxGridSize[1];
megdnn_assert(std::max(m, n) <= megdnn_assert(std::max(m, n) <=
static_cast<size_t>(std::numeric_limits<int>::max())); static_cast<size_t>(std::numeric_limits<int>::max()));
size_t kth = topk::find_kth_radix_workspace(m, n), size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
sel = topk::topk_select_workspace(m, n); sel = topk::topk_select_workspace(m, n);
auto ctsize = data.dtype.size(); auto ctsize = data.dtype.size();
switch (param().mode) { switch (param().mode) {
......
...@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) { ...@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) {
} // namespace select } // namespace select
} // namespace cuda_topk_impl } // namespace cuda_topk_impl
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) { uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length,
uint32_t grid_dim_y_limit) {
using namespace cuda_topk_impl::kth; using namespace cuda_topk_impl::kth;
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
sizeof(uint32_t); sizeof(uint32_t);
...@@ -488,6 +480,7 @@ template <typename ctype> ...@@ -488,6 +480,7 @@ template <typename ctype>
cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
void* workspace, uint32_t batch, void* workspace, uint32_t batch,
uint32_t length, int32_t lda, int32_t k, uint32_t length, int32_t lda, int32_t k,
uint32_t grid_dim_y_limit,
cudaStream_t stream) { cudaStream_t stream) {
using namespace cuda_topk_impl::kth; using namespace cuda_topk_impl::kth;
if (!k) { if (!k) {
...@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, ...@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
megdnn_trap(); megdnn_trap();
} }
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t batch_idx = 0; uint32_t batch_idx = 0;
uint32_t grid_dim_x = get_grid_dim_x(length); uint32_t grid_dim_x = get_grid_dim_x(length);
uint32_t grid_dim_y = 1; uint32_t grid_dim_y = 1;
...@@ -567,20 +550,11 @@ template <typename ctype> ...@@ -567,20 +550,11 @@ template <typename ctype>
cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
ctype* output_value, int32_t* output_idx, ctype* output_value, int32_t* output_idx,
void* workspace, uint32_t batch, uint32_t length, void* workspace, uint32_t batch, uint32_t length,
int32_t lda, int32_t k, cudaStream_t stream) { int32_t lda, int32_t k,
uint32_t batch_upper_limit, cudaStream_t stream) {
using namespace cuda_topk_impl; using namespace cuda_topk_impl;
using namespace cuda_topk_impl::select; using namespace cuda_topk_impl::select;
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t batch_upper_limit = prop.maxGridSize[1];
uint32_t length_split = DIVUP(length, REDUCE_SIZE); uint32_t length_split = DIVUP(length, REDUCE_SIZE);
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t,
...@@ -688,10 +662,10 @@ namespace topk { ...@@ -688,10 +662,10 @@ namespace topk {
#define INST(t) \ #define INST(t) \
template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \ template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \
uint32_t, int32_t, int32_t, \ uint32_t, int32_t, int32_t, \
cudaStream_t); \ uint32_t, cudaStream_t); \
template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \ template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \
void*, uint32_t, uint32_t, int32_t, \ void*, uint32_t, uint32_t, int32_t, \
int32_t, cudaStream_t) int32_t, uint32_t, cudaStream_t)
INST(float); INST(float);
INST(int32_t); INST(int32_t);
#undef INST #undef INST
......
...@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> { ...@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> {
template <typename ctype> template <typename ctype>
cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace, cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace,
uint32_t batch, uint32_t length, int32_t lda, uint32_t batch, uint32_t length, int32_t lda,
int32_t k, cudaStream_t stream); int32_t k, uint32_t grid_dim_y_limit,
cudaStream_t stream);
//! get workspace in bytes //! get workspace in bytes
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length); uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length,
uint32_t grid_dim_y_limit);
/*! /*!
* \brief select values from rows of input that compare to thresh as specified * \brief select values from rows of input that compare to thresh as specified
...@@ -90,7 +92,8 @@ template <typename ctype> ...@@ -90,7 +92,8 @@ template <typename ctype>
cudaError_t topk_select(const ctype* input, const ctype* thresh, cudaError_t topk_select(const ctype* input, const ctype* thresh,
ctype* output_value, int32_t* output_idx, ctype* output_value, int32_t* output_idx,
void* workspace, uint32_t batch, uint32_t length, void* workspace, uint32_t batch, uint32_t length,
int32_t lda, int32_t k, cudaStream_t stream); int32_t lda, int32_t k, uint32_t batch_upper_limit,
cudaStream_t stream);
uint32_t topk_select_workspace(uint32_t batch, uint32_t length); uint32_t topk_select_workspace(uint32_t batch, uint32_t length);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册