Unverified commit b93870e6, authored by zhaoyuchen2018, committed by GitHub

Improve topk performance. (#21087)

* Improve topk performance.

Computing topk on 200,000 input values:
before opt: cost 1s
after opt: cost 0.0028s.

* Refine return value.
* Add CUDA util functions.
* Fix ComputeBlockSize bug & refine comments.
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent d74ea085
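The fast path added by this commit builds on CUB's segmented radix sort: every row of the flattened [num_rows, num_cols] input is sorted in descending order in one launch, and the first k entries of each sorted row then give the top-k values and indices. Below is a minimal standalone sketch of the two-phase cub::DeviceSegmentedRadixSort::SortPairsDescending pattern (the first call only queries the temp-storage size, the second call sorts). It is an illustration, not the Paddle kernel: the sizes and data are made up, and it uses an explicit offsets array where the Paddle code derives segment offsets from a TransformInputIterator.

// Hypothetical standalone illustration of cub::DeviceSegmentedRadixSort,
// the primitive the new SortTopk path relies on. Each row of a row-major
// num_rows x num_cols matrix is sorted in descending order; taking the
// first k entries of each sorted row then yields top-k per row.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int num_rows = 2, num_cols = 4, n = num_rows * num_cols;
  std::vector<float> h_keys = {3, 1, 4, 2, 8, 5, 7, 6};  // values, two rows
  std::vector<int> h_vals = {0, 1, 2, 3, 0, 1, 2, 3};    // per-row indices
  std::vector<int> h_offsets = {0, 4, 8};                // row (segment) boundaries

  float *d_keys_in, *d_keys_out;
  int *d_vals_in, *d_vals_out, *d_offsets;
  cudaMalloc(&d_keys_in, n * sizeof(float));
  cudaMalloc(&d_keys_out, n * sizeof(float));
  cudaMalloc(&d_vals_in, n * sizeof(int));
  cudaMalloc(&d_vals_out, n * sizeof(int));
  cudaMalloc(&d_offsets, (num_rows + 1) * sizeof(int));
  cudaMemcpy(d_keys_in, h_keys.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals_in, h_vals.data(), n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offsets, h_offsets.data(), (num_rows + 1) * sizeof(int),
             cudaMemcpyHostToDevice);

  // First call with a null buffer only reports the required temp storage size.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
      num_rows, d_offsets, d_offsets + 1);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call performs the per-row descending sort.
  cub::DeviceSegmentedRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
      num_rows, d_offsets, d_offsets + 1);

  std::vector<float> sorted(n);
  cudaMemcpy(sorted.data(), d_keys_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : sorted) printf("%g ", v);  // expected: 4 3 2 1 8 7 6 5
  printf("\n");
  return 0;
}

Compiled with nvcc, this prints each row sorted in descending order; slicing the first k columns of the sorted output is what SortTopk below does with Eigen when k < num_cols.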
@@ -42,6 +42,11 @@ class TopkOp : public framework::OperatorWithKernel {
    framework::DDim dims = input_dims;
    dims[dims.size() - 1] = k;
    // If K is given as a tensor, set k = -1 since its real value is not known at this time.
if (ctx->HasInput("K")) {
dims[dims.size() - 1] = -1;
}
ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims); ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
......
@@ -12,11 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
// set cub base traits in order to handle float16
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t,
paddle::platform::float16> {};
} // namespace cub
namespace paddle {
namespace operators {
@@ -303,6 +312,157 @@ inline static int GetDesiredBlockDim(int dim) {
  }
}
// Iterator that maps a row index to that row's start offset (used to move to the next row)
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
// Iterator that maps a flat index to its column index within a row
struct ColumnIndexIter {
explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
const Eigen::array<int, 1>& ix) const {
return ix[0] % num_cols_;
}
int num_cols_;
};
__global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (int j = row_id; j < num_rows; j += gridDim.x) {
for (int i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const size_t num_cols,
const size_t num_rows, size_t k, framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream();
Tensor input_indices;
const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows),
static_cast<int64_t>(num_cols)};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
// input_indices.Resize(num_rows*num_cols);
input_indices.mutable_data<int64_t>(ctx.GetPlace());
  size_t temp_storage_bytes = 0;
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // In practice num_rows < maxGridDimX, so the grid size is usually just num_rows.
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  // Init an index array
InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
  // Create a counting iterator over the input
cub::CountingInputIterator<int> counting_iter(0);
  // segment_offsets_t is used to move to the next row
cub::TransformInputIterator<int, SegmentOffsetIter,
cub::CountingInputIterator<int>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
Tensor temp_values;
Tensor temp_indices;
const T* input = input_tensor->data<T>();
T* values = out_tensor->data<T>();
int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
if (k == num_cols) {
// Doing a full sort.
sorted_values_ptr = values;
sorted_indices_ptr = indices;
} else {
temp_values.Resize(dim);
temp_indices.Resize(dim);
sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
}
  // Get the temp storage buffer size; maybe a fixed buffer could be
  // allocated ahead of time to save this query.
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (err != cudaSuccess) {
LOG(ERROR)
<< "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
"temp_storage_bytes, status: "
<< cudaGetErrorString(err);
return false;
}
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
if (err != cudaSuccess) {
LOG(ERROR)
<< "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
"temp_storage_bytes: "
<< temp_storage_bytes << ", status: " << cudaGetErrorString(err);
return false;
}
auto& dev = *ctx.eigen_device();
if (k < num_cols) {
// copy sliced data to output.
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
auto dim = framework::make_ddim(odims);
auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values = EigenMatrix<T>::From(temp_values);
e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
}
return true;
}
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
@@ -340,13 +500,24 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    framework::DDim inputdims = input->dims();
    const size_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const size_t input_width = inputdims[inputdims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices)) {
        // Succeeded, return.
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel.";
}
}
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    if (k > input_width) k = input_width;
    // NOTE: pass lds and dim the same as input width.
@@ -354,7 +525,6 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
auto& dev_ctx = ctx.cuda_device_context();
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,
......
@@ -216,6 +216,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  compute_capability_ = GetCUDAComputeCapability(place_.device);
  multi_process_ = GetCUDAMultiProcessors(place_.device);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
@@ -342,6 +343,10 @@ bool CUDADeviceContext::tensor_core_available() const {
  return cublas_tensor_core_handle_ != nullptr;
}
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
return max_grid_dim_size_;
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
......
@@ -96,6 +96,9 @@ class CUDADeviceContext : public DeviceContext {
  /*! \brief Return the max physical thread count in the device context */
  int GetMaxPhysicalThreadCount() const;
/*! \brief Return the max grid dim size in the device context */
dim3 GetCUDAMaxGridDimSize() const;
  /*! \brief Return eigen device in the device context. */
  Eigen::GpuDevice* eigen_device() const;
@@ -184,6 +187,7 @@ class CUDADeviceContext : public DeviceContext {
  int driver_version_;
  int multi_process_;
  int max_threads_per_mp_;
dim3 max_grid_dim_size_;
  // StreamCallbackManager is thread-safe
  std::unique_ptr<StreamCallbackManager> callback_manager_;
......
@@ -99,6 +99,33 @@ int GetCUDAComputeCapability(int id) {
  return major * 10 + minor;
}
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count");
dim3 ret;
int size;
auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
PADDLE_ENFORCE_EQ(error_code_x, 0,
"cudaDevAttrMaxGridDimX failed in "
"paddle::platform::GpuMaxGridDimSize, error code : %d, %s",
error_code_x, CudaErrorWebsite());
ret.x = size;
auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
PADDLE_ENFORCE_EQ(error_code_y, 0,
"cudaDevAttrMaxGridDimY failed in "
"paddle::platform::GpuMaxGridDimSize, error code : %d, %s",
error_code_y, CudaErrorWebsite());
ret.y = size;
auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
PADDLE_ENFORCE_EQ(error_code_z, 0,
"cudaDevAttrMaxGridDimZ failed in "
"paddle::platform::GpuMaxGridDimSize, error code : %d, %s",
error_code_z, CudaErrorWebsite());
ret.z = size;
return ret;
}
int GetCUDARuntimeVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count");
  int runtime_version = 0;
......
@@ -48,6 +48,9 @@ int GetCUDAMaxThreadsPerMultiProcessor(int i);
//! Get the current GPU device id in system.
int GetCurrentDeviceId();
//! Get the maximum GridDim size of the GPU.
dim3 GetGpuMaxGridDimSize(int);
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices();
......