From 3d4f2aa6897dee09240ffef777fe98ecead7fae9 Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Wed, 8 Jan 2020 14:01:10 +0800
Subject: [PATCH] Refine stack op to improve xlnet performance, test=develop
 (#22142)

stack's Wait() costs a lot of CPU time; using a CUDA kernel to do the memory
copy instead reduces the CPU time.

Signed-off-by: zhaoyuchen
---
 paddle/fluid/operators/stack_op.cc            | 134 ++++++++++++
 paddle/fluid/operators/stack_op.cu            | 194 ++++++++++++++++--
 paddle/fluid/operators/stack_op.h             | 146 -------------
 paddle/fluid/platform/device_context.cc       |   7 +
 paddle/fluid/platform/device_context.h        |   7 +
 paddle/fluid/platform/gpu_info.cc             |  19 ++
 paddle/fluid/platform/gpu_info.h              |   3 +
 .../fluid/platform/gpu_launch_param_config.h  | 103 ++++++++++
 8 files changed, 454 insertions(+), 159 deletions(-)
 create mode 100755 paddle/fluid/platform/gpu_launch_param_config.h

diff --git a/paddle/fluid/operators/stack_op.cc b/paddle/fluid/operators/stack_op.cc
index b2e46cccfc..7e65a7a1d0 100644
--- a/paddle/fluid/operators/stack_op.cc
+++ b/paddle/fluid/operators/stack_op.cc
@@ -13,9 +13,143 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/stack_op.h"
+#include <memory>
+#include <vector>
 
 namespace plat = paddle::platform;
 namespace ops = paddle::operators;
+
+namespace paddle {
+namespace operators {
+
+class StackOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
+                      platform::errors::InvalidArgument(
+                          "Number of Inputs(X) must be larger than 0, but"
+                          " received value is:%d.",
+                          ctx->Inputs("X").size()));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(Y) of stack_op should not be null."));
+
+    auto input_dims = ctx->GetInputsDim("X");
+    for (size_t i = 1; i < input_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
+                        platform::errors::InvalidArgument(
+                            "Dims of all Inputs(X) must be the same, but"
+                            " received input %d dim is:%d, not equal to input 0"
+                            " dim:%d.",
+                            i, input_dims[i], input_dims[0]));
+    }
+
+    // Only lod of X[0] would be shared with Y
+    ctx->ShareLoD("X", /*->*/ "Y");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    int rank = input_dims[0].size();
+    PADDLE_ENFORCE_GE(
+        axis, -(rank + 1),
+        platform::errors::InvalidArgument(
+            "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
+            "but received axis is:%d.",
+            rank, axis));
+
+    PADDLE_ENFORCE_LT(
+        axis, rank + 1,
+        platform::errors::InvalidArgument(
+            "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
+            "but received axis is:%d.",
+            rank, axis));
+
+    if (axis < 0) axis += (rank + 1);
+
+    auto vec = framework::vectorize(input_dims[0]);
+    vec.insert(vec.begin() + axis, input_dims.size());
+    ctx->SetOutputDim("Y", framework::make_ddim(vec));
+  }
+};
+
+class StackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of stack op.").AsDuplicable();
+    AddOutput("Y", "The output of stack op.");
+    AddAttr<int>("axis",
+                 "The axis along which all of the Inputs(X) should be stacked.")
+        .SetDefault(0);
+    AddComment(R"DOC(
+Stack Operator.
+Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
+)DOC"); + } +}; + +class StackOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Y")), true, + platform::errors::InvalidArgument("Input(Y@Grad) not exist.")); + + int axis = ctx->Attrs().Get("axis"); + auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = dy_dim.size(); + PADDLE_ENFORCE_GE( + axis, -rank, + platform::errors::InvalidArgument( + "Attr(axis) must be inside [-rank, rank), where rank = %d, " + "but received axis is:%d.", + rank, axis)); + PADDLE_ENFORCE_LT( + axis, rank, + platform::errors::InvalidArgument( + "Attr(axis) must be inside [-rank, rank), where rank = %d, " + "but received axis is:%d.", + rank, axis)); + + if (axis < 0) axis += rank; + PADDLE_ENFORCE_EQ( + ctx->Outputs(framework::GradVarName("X")).size(), + static_cast(dy_dim[axis]), + platform::errors::InvalidArgument( + "Number of Outputs(X@Grad) is equal to dy dim at axis, but" + " received outputs size is:%d, dy dims is:%d.", + ctx->Outputs(framework::GradVarName("X")).size(), + static_cast(dy_dim[axis]))); + + auto vec = framework::vectorize(dy_dim); + vec.erase(vec.begin() + axis); + ctx->SetOutputsDim( + framework::GradVarName("X"), + std::vector(dy_dim[axis], framework::make_ddim(vec))); + } +}; + +template +class StackGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new T()); + op->SetType("stack_grad"); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); + op->SetAttrMap(this->Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker, ops::StackGradOpMaker, ops::StackGradOpMaker); diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu index 24d0b2f906..e7bc91f5c8 100644 --- a/paddle/fluid/operators/stack_op.cu +++ b/paddle/fluid/operators/stack_op.cu @@ -12,21 +12,189 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <limits>
+#include <memory>
+#include <vector>
 #include "paddle/fluid/operators/stack_op.h"
+#include "paddle/fluid/platform/gpu_launch_param_config.h"
 
 namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 
-REGISTER_OP_CUDA_KERNEL(
-    stack, ops::StackKernel<plat::CUDADeviceContext, float>,
-    ops::StackKernel<plat::CUDADeviceContext, double>,
-    ops::StackKernel<plat::CUDADeviceContext, int>,
-    ops::StackKernel<plat::CUDADeviceContext, int64_t>,
-    ops::StackKernel<plat::CUDADeviceContext, plat::float16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    stack_grad, ops::StackGradKernel<plat::CUDADeviceContext, float>,
-    ops::StackGradKernel<plat::CUDADeviceContext, double>,
-    ops::StackGradKernel<plat::CUDADeviceContext, int>,
-    ops::StackGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::StackGradKernel<plat::CUDADeviceContext, plat::float16>);
+namespace paddle {
+namespace operators {
+
+template <typename T, typename IntType>
+__global__ void StackCUDAKernel(T** input_ptrs, int split_size, int rows,
+                                int cols, T* __restrict__ output) {
+  IntType grid_x = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (; grid_x < cols; grid_x += blockDim.x * gridDim.x) {
+    IntType grid_y = blockIdx.y * blockDim.y + threadIdx.y;
+
+    IntType split = grid_x / split_size;
+    const T* input_ptr = input_ptrs[split];
+    IntType col_offset = grid_x % split_size;
+#pragma unroll
+    for (; grid_y < rows; grid_y += blockDim.y * gridDim.y) {
+      output[grid_y * cols + grid_x] =
+          input_ptr[grid_y * split_size + col_offset];
+    }
+  }
+}
+
+template <typename T>
+class StackGPUKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.MultiInput<Tensor>("X");
+    auto* y = ctx.Output<Tensor>("Y");
+
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+
+    int n = static_cast<int>(x.size());
+    auto* y_data = y->mutable_data<T>(ctx.GetPlace());
+    std::vector<const T*> x_datas(n);
+    for (int i = 0; i < n; i++) {
+      x_datas[i] = x[i]->data<T>();
+    }
+
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto tmp_x_data = memory::Alloc(dev_ctx, x_datas.size() * sizeof(T*));
+    memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
+                 tmp_x_data->ptr(), platform::CPUPlace(),
+                 reinterpret_cast<void*>(x_datas.data()),
+                 x_datas.size() * sizeof(T*), dev_ctx.stream());
+
+    // Split x dim from axis to matrix
+    int x_row = 1, x_col = 1;
+    for (int i = 0; i < axis; ++i) {
+      x_row *= x[0]->dims()[i];
+    }
+    x_col = x[0]->numel() / x_row;
+    int out_col = x_col * n;
+
+    auto config = GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
+
+    if (y->numel() < std::numeric_limits<int32_t>::max()) {
+      StackCUDAKernel<T, int32_t><<<config.block_per_grid,
+                                    config.thread_per_block, 0,
+                                    dev_ctx.stream()>>>(
+          reinterpret_cast<T**>(tmp_x_data->ptr()), x_col, x_row, out_col,
+          y_data);
+    } else {
+      StackCUDAKernel<T, int64_t><<<config.block_per_grid,
+                                    config.thread_per_block, 0,
+                                    dev_ctx.stream()>>>(
+          reinterpret_cast<T**>(tmp_x_data->ptr()), x_col, x_row, out_col,
+          y_data);
+    }
+  }
+};
+
+template <typename T, typename IntType>
+__global__ void UnStackCUDAKernel(const T* __restrict__ input,
+                                  int pre_dim_size, int split_dim_size,
+                                  int suf_dim_size, int num_split,
+                                  T** output_ptrs) {
+  assert(blockDim.y == 1);
+  assert(blockDim.z == 1);
+  // In this case they are equal
+  assert(split_dim_size % num_split == 0);
+
+  IntType size = pre_dim_size * split_dim_size * suf_dim_size;
+  IntType each_dim_size = split_dim_size / num_split;
+
+  for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size;
+       offset += blockDim.x * gridDim.x) {
+    IntType i = offset / (split_dim_size * suf_dim_size);
+    IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
+    IntType k = offset % suf_dim_size;
+
+    T* output = output_ptrs[j / each_dim_size];
+    IntType output_ind = i * each_dim_size * suf_dim_size +
+                         (j % each_dim_size) * suf_dim_size + k;
+    *(output + output_ind) = input[offset];
+  }
+}
+
+template <typename T>
+class StackGradGPUKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += dy->dims().size();
+
+    int n = dy->dims()[axis];
+    PADDLE_ENFORCE_EQ(n, dx.size(),
+                      platform::errors::InvalidArgument(
+                          "Output dx size should be equal to n, but"
+                          " received n is:%d, dx size is:%d.",
+                          n, dx.size()));
+
+    // dx is output, so save each data address, then copy each dy into dx_data
+    std::vector<T*> outputs(n);
+    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
+    for (size_t j = 0; j < dx.size(); ++j) {
+      if (out_var_names[j] != framework::kEmptyVarName &&
+          dx[j]->numel() != 0UL) {
+        T* ptr = dx[j]->mutable_data<T>(ctx.GetPlace());
+        outputs[j] = ptr;
+      } else {
+        outputs[j] = nullptr;
+      }
+    }
+    auto dy_data = dy->data<T>();
+    // each dx should have the same shape
+    int dy_pre = 1, dy_suf = 1;
+    auto dy_dims = dy->dims();
+    int split_dim = n;
+    for (int i = 0; i < axis; ++i) {
+      dy_pre *= dy_dims[i];
+    }
+    dy_suf = dy->numel() / (split_dim * dy_pre);
+
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto tmp_out_data = memory::Alloc(dev_ctx, outputs.size() * sizeof(T*));
+    memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
+                 tmp_out_data->ptr(), platform::CPUPlace(),
+                 reinterpret_cast<void*>(outputs.data()),
+                 outputs.size() * sizeof(T*), dev_ctx.stream());
+
+    auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf);
+
+    if (dy->numel() < std::numeric_limits<int32_t>::max()) {
+      UnStackCUDAKernel<
+          T, int32_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
+                        dev_ctx.stream()>>>(
+          dy_data, dy_pre, split_dim, dy_suf, split_dim,
+          reinterpret_cast<T**>(tmp_out_data->ptr()));
+    } else {
+      UnStackCUDAKernel<
+          T, int64_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
+                        dev_ctx.stream()>>>(
+          dy_data, dy_pre, split_dim, dy_suf, split_dim,
+          reinterpret_cast<T**>(tmp_out_data->ptr()));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(stack, ops::StackGPUKernel<float>,
+                        ops::StackGPUKernel<double>, ops::StackGPUKernel<int>,
+                        ops::StackGPUKernel<int64_t>,
+                        ops::StackGPUKernel<plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(stack_grad, ops::StackGradGPUKernel<float>,
+                        ops::StackGradGPUKernel<double>,
+                        ops::StackGradGPUKernel<int>,
+                        ops::StackGradGPUKernel<int64_t>,
+                        ops::StackGradGPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h
index 9cb391c048..38ab60afd9 100644
--- a/paddle/fluid/operators/stack_op.h
+++ b/paddle/fluid/operators/stack_op.h
@@ -18,80 +18,9 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
 
-#ifdef __NVCC__
-#include <thrust/device_vector.h>
-#include "paddle/fluid/framework/array.h"
-#endif
-
 namespace paddle {
 namespace operators {
 
-class StackOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
-                      "Number of Inputs(X) must be larger than 0");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
-
-    auto input_dims = ctx->GetInputsDim("X");
-    for (size_t i = 1; i < input_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
-                        "Dims of all Inputs(X) must be the same");
-    }
-
-    // Only lod of X[0] would be shared with Y
-    ctx->ShareLoD("X", /*->*/ "Y");
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    int rank = input_dims[0].size();
-    PADDLE_ENFORCE(
-        axis >= -(rank + 1) && axis < rank + 1,
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
-    if (axis < 0) axis += (rank + 1);
-
-    auto vec = framework::vectorize(input_dims[0]);
-    vec.insert(vec.begin() + axis, input_dims.size());
-    ctx->SetOutputDim("Y", framework::make_ddim(vec));
-  }
-};
-
-class StackOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "The input of stack op.").AsDuplicable();
-    AddOutput("Y", "The output of stack op.");
-    AddAttr<int>("axis",
-                 "The axis along which all of the Inputs(X) should be stacked.")
-        .SetDefault(0);
-    AddComment(R"DOC(
-      Stack Operator.
-
-      Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
-    )DOC");
-  }
-};
-
-template <typename VecXType, typename T>
-struct StackFunctor {
-  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
-      : x_(x), y_(y), n_(n), post_(post) {}
-
-  HOSTDEVICE void operator()(int idx) {
-    int i = idx / (n_ * post_);
-    int which_x = idx / post_ - i * n_;
-    int x_index = i * post_ + idx % post_;
-    y_[idx] = x_[which_x][x_index];
-  }
-
- private:
-  VecXType x_;
-  T *y_;
-  int n_;
-  int post_;
-};
-
 template <typename VecDxType, typename T>
 struct StackGradFunctor {
   HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
@@ -111,14 +40,6 @@ struct StackGradFunctor {
   int post_;
 };
 
-template <typename DeviceContext, typename VecXType, typename T>
-static inline void StackFunctorForRange(const DeviceContext &ctx,
-                                        const VecXType &x, T *y, int total_num,
-                                        int n, int post) {
-  platform::ForRange<DeviceContext> for_range(ctx, total_num);
-  for_range(StackFunctor<VecXType, T>(x, y, n, post));
-}
-
 template <typename DeviceContext, typename VecDxType, typename T>
 static inline void StackGradFunctorForRange(const DeviceContext &ctx,
                                             const VecDxType &dx, const T *dy,
@@ -149,19 +70,6 @@ class StackKernel : public framework::OpKernel<T> {
     for (auto i = 0; i < axis; ++i) pre *= dim[i];
     for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
 
-#ifdef __NVCC__
-    int total_num = pre * n * post;
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-
-    thrust::device_vector<const T *> device_x_vec(x_datas);
-    auto x_data_arr = device_x_vec.data().get();
-
-    StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
-
-    // Wait() must be called because device_x_vec may be destructed before
-    // kernel ends
-    dev_ctx.Wait();
-#else
     auto x_data_arr = x_datas.data();
 
     size_t x_offset = 0;
@@ -174,50 +82,6 @@ class StackKernel : public framework::OpKernel<T> {
       }
       x_offset += post;
     }
-#endif
-  }
-};
-
-class StackOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@Grad) must exist.");
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = dy_dim.size();
-    PADDLE_ENFORCE(axis >= -rank && axis < rank,
-                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
-                   rank);
-    if (axis < 0) axis += rank;
-
-    PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
-                      static_cast<size_t>(dy_dim[axis]),
-                      "Number of Outputs(X@Grad) is wrong");
-    auto vec = framework::vectorize(dy_dim);
-    vec.erase(vec.begin() + axis);
-    ctx->SetOutputsDim(
-        framework::GradVarName("X"),
-        std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
-  }
-};
-
-template <typename T>
-class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
-    op->SetType("stack_grad");
-    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
-    op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
@@ -245,18 +109,8 @@ class StackGradKernel : public framework::OpKernel<T> {
     int post = total_num / (n * pre);
 
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
-#ifdef __NVCC__
-    thrust::device_vector<T *> device_dx_vec(dx_datas);
-    auto dx_data_arr = device_dx_vec.data().get();
-#else
    auto dx_data_arr = dx_datas.data();
-#endif
    StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
-#ifdef __NVCC__
-    // Wait() must be called because device_dx_vec may be destructed before
-    // kernel ends
-    dev_ctx.Wait();
-#endif
   }
 };
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index b038f68738..62c83e0c4e 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -217,6 +217,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   multi_process_ = GetCUDAMultiProcessors(place_.device);
   max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
   max_grid_dim_size_ = GetGpuMaxGridDimSize(place_.device);
+  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(place_.device);
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
@@ -338,6 +339,12 @@ int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
   return multi_process_ * max_threads_per_mp_;
 }
 
+int CUDADeviceContext::GetSMCount() const { return multi_process_; }
+
+int CUDADeviceContext::GetMaxThreadsPerBlock() const {
+  return max_threads_per_block_;
+}
+
 Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 344ac69f97..50e4538987 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -97,6 +97,12 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief  Return the max physical thread count in the device context */
   int GetMaxPhysicalThreadCount() const;
 
+  /*! \brief  Return the SM count in the device context */
+  int GetSMCount() const;
+
+  /*! \brief  Return the max number of threads per block in the device context */
+  int GetMaxThreadsPerBlock() const;
+
   /*! \brief  Return the max grid dim size in the device context */
   dim3 GetCUDAMaxGridDimSize() const;
 
@@ -188,6 +194,7 @@ class CUDADeviceContext : public DeviceContext {
   int driver_version_;
   int multi_process_;
   int max_threads_per_mp_;
+  int max_threads_per_block_;
   dim3 max_grid_dim_size_;
 
   // StreamCallbackManager is thread-safe
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc
index a1020bf2d4..b3f00bf7c0 100644
--- a/paddle/fluid/platform/gpu_info.cc
+++ b/paddle/fluid/platform/gpu_info.cc
@@ -183,6 +183,25 @@ int GetCUDAMaxThreadsPerMultiProcessor(int id) {
   return count;
 }
 
+int GetCUDAMaxThreadsPerBlock(int id) {
+  PADDLE_ENFORCE_LT(
+      id, GetCUDADeviceCount(),
+      platform::errors::InvalidArgument(
+          "Device id must be less than GPU count, but received id is:%d, "
+          "GPU count is:%d.",
+          id, GetCUDADeviceCount()));
+  int count;
+  auto error_code =
+      cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id);
+  PADDLE_ENFORCE_EQ(
+      error_code, 0,
+      platform::errors::InvalidArgument(
+          "cudaDeviceGetAttribute should return error code 0, "
+          "but received error code is:%d, %s.",
+          error_code, CudaErrorWebsite()));
+  return count;
+}
+
 int GetCurrentDeviceId() {
   int device_id;
   auto error_code = cudaGetDevice(&device_id);
diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h
index ca917c924f..6ed2b344b9 100644
--- a/paddle/fluid/platform/gpu_info.h
+++ b/paddle/fluid/platform/gpu_info.h
@@ -45,6 +45,9 @@ int GetCUDAMultiProcessors(int i);
 //! Get the MaxThreads of each MultiProcessor of the ith GPU.
 int GetCUDAMaxThreadsPerMultiProcessor(int i);
 
+//! Get the MaxThreads of each block of the ith GPU.
+int GetCUDAMaxThreadsPerBlock(int i);
+
 //! Get the current GPU device id in system.
 int GetCurrentDeviceId();
 
diff --git a/paddle/fluid/platform/gpu_launch_param_config.h b/paddle/fluid/platform/gpu_launch_param_config.h
new file mode 100755
index 0000000000..c1ea063360
--- /dev/null
+++ b/paddle/fluid/platform/gpu_launch_param_config.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Used for computing GPU launch parameters.
+
+#pragma once
+
+#ifdef PADDLE_WITH_CUDA
+
+#include <cuda_runtime.h>
+#include <stddef.h>
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace platform {
+
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+
+struct GpuLaunchParamConfig {
+  dim3 theory_thread_count = dim3(0, 0, 0);
+  dim3 thread_per_block = dim3(0, 0, 0);
+  dim3 block_per_grid = dim3(0, 0, 0);
+};
+
+inline GpuLaunchParamConfig GetGpuLaunchConfig1D(
+    const platform::CUDADeviceContext& context, int element_count) {
+  PADDLE_ENFORCE_GT(element_count, 0,
+                    platform::errors::InvalidArgument(
+                        "element count should be greater than 0,"
+                        " but received value is:%d.",
+                        element_count));
+
+  const int theory_thread_count = element_count;
+  // Get the max number of threads across all SMs.
+  int max_physical_threads = context.GetMaxPhysicalThreadCount();
+  int sm = context.GetSMCount();
+
+  // Compute the physical threads we need; it should be no larger than the max
+  // number of threads across all SMs.
+  const int physical_thread_count =
+      std::min(max_physical_threads, theory_thread_count);
+
+  // Needs to be queried from the device.
+  const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
+  // Keep the block count below factor * sm; factor is an empirical value.
+  int factor = 4;
+  const int block_count =
+      std::min(DivUp(physical_thread_count, thread_per_block), factor * sm);
+
+  GpuLaunchParamConfig config;
+  config.theory_thread_count.x = theory_thread_count;
+  config.thread_per_block.x = thread_per_block;
+  config.block_per_grid.x = block_count;
+  return config;
+}
+
+inline GpuLaunchParamConfig GetGpuLaunchConfig2D(
+    const platform::CUDADeviceContext& context, int xdim, int ydim) {
+  PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument(
+                                 "x dim number should be greater than 0,"
+                                 " but received value is:%d.",
+                                 xdim));
+  PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument(
+                                 "y dim number should be greater than 0,"
+                                 " but received value is:%d.",
+                                 ydim));
+
+  const int kThreadsPerBlock = 256;
+  int block_cols = std::min(xdim, kThreadsPerBlock);
+  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+
+  int max_physical_threads = context.GetMaxPhysicalThreadCount();
+  const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
+
+  GpuLaunchParamConfig config;
+  // Note: the block size is not aligned to 32; align it yourself if needed.
+  config.theory_thread_count = dim3(xdim, ydim, 1);
+  config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
+  int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
+
+  config.block_per_grid = dim3(grid_x, grid_y, 1);
+  return config;
+}
+
+// 3D will be added later.
+
+}  // namespace platform
+}  // namespace paddle
+
+#endif
-- 
GitLab
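A minimal usage sketch of the GetGpuLaunchConfig1D helper introduced by this patch, not part of the patch itself: it assumes a CUDADeviceContext and a device buffer are already available, and the kernel FillKernel and wrapper LaunchFill are hypothetical names used only for illustration of how the block/grid sizes from GpuLaunchParamConfig feed a kernel launch on the context's stream.

// Illustrative only; not part of the patch above.
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"

namespace paddle {
namespace operators {

// Hypothetical element-wise kernel: a grid-stride loop so any launch
// configuration covers all n elements.
template <typename T>
__global__ void FillKernel(T* out, int n, T value) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = value;
  }
}

// Launches FillKernel with block/grid sizes derived from the device limits
// that GetGpuLaunchConfig1D queries via GetSMCount()/GetMaxThreadsPerBlock().
template <typename T>
void LaunchFill(const platform::CUDADeviceContext& dev_ctx, T* out, int n,
                T value) {
  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n);
  FillKernel<T><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                  dev_ctx.stream()>>>(out, n, value);
}

}  // namespace operators
}  // namespace paddle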