提交 131ec276 编写于 作者: C chengduoZH

fix bug for big number; float->double and code refine

上级 82bd82c1
......@@ -70,7 +70,7 @@ __global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
......@@ -113,7 +113,7 @@ __global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
float inv_input_col = 1.0 / input_col;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
......@@ -145,8 +145,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
paddle::framework::Vector<int> inputs_cols(num + 1);
framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
framework::Vector<int> inputs_cols(num + 1);
inputs_cols[0] = 0;
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
......@@ -168,15 +168,14 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = std::min(out_cols, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int block_cols = kThreadsPerBlock;
if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int dev_id = paddle::platform::GetCurrentDeviceId();
int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
int max_threads_per_mp =
paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
int max_threads = multi_process * max_threads_per_mp;
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
......@@ -218,8 +217,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
int input_col = 0;
bool sameShape = true;
paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
paddle::framework::Vector<int> outputs_cols(num + 1);
framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(num + 1);
outputs_cols[0] = 0;
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
......@@ -239,12 +238,20 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
// computation
const int kThreadsPerBlock = 1024;
int block_cols = std::min(input_col, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int block_cols = kThreadsPerBlock;
if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((input_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int grid_cols = (input_col + block_cols - 1) / block_cols;
int grid_rows = (input_row + block_rows - 1) / block_rows;
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((input_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
......
......@@ -121,6 +121,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
......@@ -154,6 +156,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaGetLastError());
}
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
return multi_process * max_threads_per_mp;
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get();
}
......
......@@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
Place GetPlace() const override;
/*! \brief Return the max physical thread count in the device context */
int GetMaxPhysicalThreadCount() const;
/*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const;
......@@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int multi_process;
int max_threads_per_mp;
};
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册