diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 7ad49de4eed5e26cdc24a7444ead9a50abf54453..02b8aa04de132c4734567b9286fb094f07e9d310 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -288,7 +288,6 @@ struct SearchAlgorithm { ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto& temp = ctx.cuda_device_context(); AlgorithmsCache& algo_cache = *(framework::ConvSearchCache::Instance().GetForward()); diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu index f616e116d0aee7ef0c450f00e5d14dc57d6f7438..344a28c6f03b8796134771459c06a6b7f4df962c 100644 --- a/paddle/fluid/operators/math/im2col.cu +++ b/paddle/fluid/operators/math/im2col.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/backends/gpu/gpu_context.h" namespace paddle { namespace operators { @@ -73,12 +74,12 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height, * col = * [input_channels, filter_height, filter_width, output_height, output_width] */ -template -class Im2ColFunctor { +template +class Im2ColFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& im, const std::vector& dilation, + void operator()(const DeviceContext& context, const framework::Tensor& im, + const std::vector& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* col, const DataLayout data_layout) { @@ -184,12 +185,11 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, * col = * [input_channels, filter_height, filter_width, output_height, output_width] */ -template -class Col2ImFunctor { +template +class Col2ImFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& col, + void operator()(const DeviceContext& context, const framework::Tensor& col, const std::vector& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* im, @@ -257,10 +257,18 @@ template class Im2ColFunctor; template class Im2ColFunctor; +template class Im2ColFunctor; +template class Im2ColFunctor; template class Col2ImFunctor; template class Col2ImFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; template __global__ void im2colOCF(const T* im_data, int im_channels, int im_height, @@ -299,12 +307,12 @@ __global__ void im2colOCF(const T* im_data, int im_channels, int im_height, * col = * [output_height, output_width, input_channels, filter_height, filter_width] */ -template -class Im2ColFunctor { +template +class Im2ColFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& im, const std::vector& dilation, + void operator()(const DeviceContext& context, const framework::Tensor& im, + const std::vector& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* col, const DataLayout data_layout) { @@ -390,12 +398,11 @@ __global__ void col2imOCF(const T* col_data, int im_channels, int im_height, * col = * [output_height, output_width, input_channels, filter_height, filter_width] */ -template -class Col2ImFunctor { +template +class Col2ImFunctor { public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& col, + void operator()(const DeviceContext& context, const framework::Tensor& col, const std::vector& dilation, const std::vector& stride, const std::vector& padding, framework::Tensor* im, @@ -464,10 +471,19 @@ template class Im2ColFunctor; template class Im2ColFunctor; +template class Im2ColFunctor; +template class Im2ColFunctor; + template class Col2ImFunctor; template class Col2ImFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index d9c757544a9c6abbac12e0f09278987db854a694..b6f7bae4da37d807850a30afdc814b33519f93d8 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/vol2col.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/pten/backends/gpu/gpu_context.h" namespace paddle { namespace operators { @@ -82,93 +83,91 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, * [input_channels, filter_depth, filter_height, filter_width, * output_depth, output_height, output_width] */ -template -class Vol2ColFunctor { - public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& vol, - const std::vector& dilations, - const std::vector& strides, - const std::vector& paddings, framework::Tensor* col, - const DataLayout data_layout) const { - PADDLE_ENFORCE_EQ(vol.dims().size(), 4, - platform::errors::InvalidArgument( - "The dimension of vol should be 4, but received %d.", - vol.dims().size())); - PADDLE_ENFORCE_EQ(col->dims().size(), 7, - platform::errors::InvalidArgument( - "The dimension of col should be 7, but received %d.", - col->dims().size())); +// template +// class Vol2ColFunctor { +// public: +template +void Vol2ColFunctor::operator()( + const DeviceContext& context, const framework::Tensor& vol, + const std::vector& dilations, const std::vector& strides, + const std::vector& paddings, framework::Tensor* col, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol.dims().size(), 4, + platform::errors::InvalidArgument( + "The dimension of vol should be 4, but received %d.", + vol.dims().size())); + PADDLE_ENFORCE_EQ(col->dims().size(), 7, + platform::errors::InvalidArgument( + "The dimension of col should be 7, but received %d.", + col->dims().size())); - int input_channels = - (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); - int input_depth = - (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]); - int input_height = - (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]); - int input_width = - (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]); - int filter_depth = col->dims()[1]; - int filter_height = col->dims()[2]; - int filter_width = col->dims()[3]; - int output_depth = col->dims()[4]; - int output_height = col->dims()[5]; - int output_width = col->dims()[6]; + int input_channels = + (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); + int input_depth = + (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]); + int input_height = + (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]); + int input_width = + (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]); + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; - bool paddings_size_is_6 = (paddings.size() == 6); - int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; - int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; - int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; - int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; - int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; - int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; - auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back - - ((dilations[0] * (filter_depth - 1) + 1))) / - strides[0] + - 1; - PADDLE_ENFORCE_EQ( - input_depth_tmp, output_depth, - platform::errors::InvalidArgument( - "input_depth(%d) and output_depth(%d) are mismatching.", - input_depth_tmp, output_depth)); - auto input_height_tmp = (input_height + pad_h_up + pad_h_down - - ((dilations[1] * (filter_height - 1) + 1))) / - strides[1] + - 1; - PADDLE_ENFORCE_EQ( - input_height_tmp, output_height, - platform::errors::InvalidArgument( - "input_height(%d) and output_height(%d) are mismatching.", - input_height_tmp, output_height)); - auto input_width_tmp = (input_width + pad_w_left + pad_w_right - - ((dilations[2] * (filter_width - 1) + 1))) / - strides[2] + - 1; - PADDLE_ENFORCE_EQ( - input_width_tmp, output_width, - platform::errors::InvalidArgument( - "input_width(%d) and output_width(%d) are mismatching.", - input_width_tmp, output_width)); + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + + 1; + PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth, + platform::errors::InvalidArgument( + "input_depth(%d) and output_depth(%d) are mismatching.", + input_depth_tmp, output_depth)); + auto input_height_tmp = (input_height + pad_h_up + pad_h_down - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + + 1; + PADDLE_ENFORCE_EQ( + input_height_tmp, output_height, + platform::errors::InvalidArgument( + "input_height(%d) and output_height(%d) are mismatching.", + input_height_tmp, output_height)); + auto input_width_tmp = (input_width + pad_w_left + pad_w_right - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + + 1; + PADDLE_ENFORCE_EQ(input_width_tmp, output_width, + platform::errors::InvalidArgument( + "input_width(%d) and output_width(%d) are mismatching.", + input_width_tmp, output_width)); - int num_outputs = - input_channels * output_depth * output_height * output_width; + int num_outputs = + input_channels * output_depth * output_height * output_width; - int max_threads = 1024; + int max_threads = 1024; #ifdef WITH_NV_JETSON - platform::ChangeThreadNum(context, &max_threads); + platform::ChangeThreadNum(context, &max_threads); #endif - const int threads = max_threads; - const int blocks = (num_outputs + max_threads - 1) / max_threads; + const int threads = max_threads; + const int blocks = (num_outputs + max_threads - 1) / max_threads; - vol2col<<>>( - num_outputs, vol.data(), input_depth, input_height, input_width, - dilations[0], dilations[1], dilations[2], filter_depth, filter_height, - filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, - pad_w_left, output_depth, output_height, output_width, col->data(), - data_layout); - } -}; + vol2col<<>>( + num_outputs, vol.data(), input_depth, input_height, input_width, + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, + pad_w_left, output_depth, output_height, output_width, col->data(), + data_layout); +} +// }; template __global__ void col2vol(int num_kernels, const T* data_col, int depth, @@ -249,98 +248,101 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, * [input_channels, filter_depth, filter_height, filter_width, * output_depth, output_height, output_width] */ -template -class Col2VolFunctor { - public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& col, - const std::vector& dilations, - const std::vector& strides, - const std::vector& paddings, framework::Tensor* vol, - const DataLayout data_layout) const { - PADDLE_ENFORCE_EQ(vol->dims().size(), 4, - platform::errors::InvalidArgument( - "The dimension of vol should be 4, but received %d.", - vol->dims().size())); - PADDLE_ENFORCE_EQ(col.dims().size(), 7, - platform::errors::InvalidArgument( - "The dimension of col should be 7, but received %d.", - col.dims().size())); +// template +// class Col2VolFunctor { +// public: +template +void Col2VolFunctor::operator()( + const DeviceContext& context, const framework::Tensor& col, + const std::vector& dilations, const std::vector& strides, + const std::vector& paddings, framework::Tensor* vol, + const DataLayout data_layout) const { + PADDLE_ENFORCE_EQ(vol->dims().size(), 4, + platform::errors::InvalidArgument( + "The dimension of vol should be 4, but received %d.", + vol->dims().size())); + PADDLE_ENFORCE_EQ(col.dims().size(), 7, + platform::errors::InvalidArgument( + "The dimension of col should be 7, but received %d.", + col.dims().size())); - int input_channels = - (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); - int input_depth = - (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]); - int input_height = - (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]); - int input_width = - (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]); - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int input_channels = + (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); + int input_depth = + (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]); + int input_height = + (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]); + int input_width = + (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]); + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; - bool paddings_size_is_6 = (paddings.size() == 6); - int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; - int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; - int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; - int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; - int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; - int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; + bool paddings_size_is_6 = (paddings.size() == 6); + int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0]; + int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0]; + int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1]; + int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1]; + int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2]; + int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2]; - auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back - - ((dilations[0] * (filter_depth - 1) + 1))) / - strides[0] + - 1; - PADDLE_ENFORCE_EQ( - input_depth_tmp, output_depth, - platform::errors::InvalidArgument( - "input_depth(%d) and output_depth(%d) are mismatching.", - input_depth_tmp, output_depth)); - auto input_height_tmp = (input_height + pad_h_up + pad_h_down - - ((dilations[1] * (filter_height - 1) + 1))) / - strides[1] + - 1; - PADDLE_ENFORCE_EQ( - input_height_tmp, output_height, - platform::errors::InvalidArgument( - "input_height(%d) and output_height(%d) are mismatching.", - input_height_tmp, output_height)); - auto input_width_tmp = (input_width + pad_w_left + pad_w_right - - ((dilations[2] * (filter_width - 1) + 1))) / - strides[2] + - 1; - PADDLE_ENFORCE_EQ( - input_width_tmp, output_width, - platform::errors::InvalidArgument( - "input_width(%d) and output_width(%d) are mismatching.", - input_width_tmp, output_width)); + auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + + 1; + PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth, + platform::errors::InvalidArgument( + "input_depth(%d) and output_depth(%d) are mismatching.", + input_depth_tmp, output_depth)); + auto input_height_tmp = (input_height + pad_h_up + pad_h_down - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + + 1; + PADDLE_ENFORCE_EQ( + input_height_tmp, output_height, + platform::errors::InvalidArgument( + "input_height(%d) and output_height(%d) are mismatching.", + input_height_tmp, output_height)); + auto input_width_tmp = (input_width + pad_w_left + pad_w_right - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + + 1; + PADDLE_ENFORCE_EQ(input_width_tmp, output_width, + platform::errors::InvalidArgument( + "input_width(%d) and output_width(%d) are mismatching.", + input_width_tmp, output_width)); - int num_kernels = input_channels * input_depth * input_height * input_width; + int num_kernels = input_channels * input_depth * input_height * input_width; - int max_threads = 1024; + int max_threads = 1024; #ifdef WITH_NV_JETSON - platform::ChangeThreadNum(context, &max_threads); + platform::ChangeThreadNum(context, &max_threads); #endif - const int threads = max_threads; - const int blocks = (num_kernels + max_threads - 1) / max_threads; + const int threads = max_threads; + const int blocks = (num_kernels + max_threads - 1) / max_threads; - col2vol<<>>( - num_kernels, col.data(), input_depth, input_height, input_width, - dilations[0], dilations[1], dilations[2], filter_depth, filter_height, - filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, - pad_w_left, output_depth, output_height, output_width, vol->data(), - data_layout); - } -}; + col2vol<<>>( + num_kernels, col.data(), input_depth, input_height, input_width, + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up, + pad_w_left, output_depth, output_height, output_width, vol->data(), + data_layout); +} +// }; template class Vol2ColFunctor; template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Vol2ColFunctor; + template class Col2VolFunctor; template class Col2VolFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 1e674258334b027840cd8e1daad93ab8cbb1a618..966dcf7770da8c30a0bb95b4cbae6c9e31836e90 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream/cuda_stream.h" #include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/allocator.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" @@ -485,8 +486,11 @@ CUDAContext::~CUDAContext() { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : pten::GPUContext(place) { pten::GPUContext::PartialInitWithoutAllocator(); - cuda_stream_.reset( - new stream::CUDAStream(pten::GPUContext::stream(), this->GetPlace())); + cuda_stream_.reset(new stream::CUDAStream(pten::GPUContext::stream(), place)); + workspace_.reset(new pten::DnnWorkspaceHandle( + memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place, pten::GPUContext::stream()) + .get())); } CUDADeviceContext::~CUDADeviceContext() = default; @@ -571,8 +575,15 @@ void CUDADeviceContext::WaitStreamCallback() const { pten::GPUContext::WaitStreamCallback(); } -CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { - return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); +pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { + if (thread_ctx_.count(this)) { + // return workspace_.get(); + return pten::DnnWorkspaceHandle( + memory::allocation::AllocatorFacade::Instance() + .GetAllocator(GetPlace(), pten::GPUContext::stream()) + .get()); + } + return pten::GPUContext::cudnn_workspace_handle(); } gpuStream_t CUDADeviceContext::stream() const { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 4d469e92c04cdded0cd9cebe8ccb2f80ef6e6a76..80dcf6d2ec23cea4f375f54d5d9f1b6e24f382cb 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -566,7 +566,7 @@ class CUDADeviceContext : public pten::GPUContext { * workspace. Once the handle is destructed, the lock would be released. * CudnnWorkspaceHandle is an RAII object to implement thread-safe * sequential cudnn function calls. */ - CudnnWorkspaceHandle cudnn_workspace_handle() const; + pten::DnnWorkspaceHandle cudnn_workspace_handle() const; /*! \brief Return cuda stream in the device context. */ gpuStream_t stream() const; @@ -607,6 +607,7 @@ class CUDADeviceContext : public pten::GPUContext { // NOTE: Just for compatibility with the past, please delete if there is an // elegant way. std::unique_ptr cuda_stream_; + std::unique_ptr workspace_{nullptr}; DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; diff --git a/paddle/pten/backends/gpu/gpu_context.cc b/paddle/pten/backends/gpu/gpu_context.cc index 1e707c46cc93de22d8cac1f81197046d361fa4fa..0ddcfb85602f8dd611282e9ddbcf61b9612350e3 100644 --- a/paddle/pten/backends/gpu/gpu_context.cc +++ b/paddle/pten/backends/gpu/gpu_context.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pten/backends/gpu/gpu_context.h" +#include #include #include #include @@ -153,55 +154,14 @@ static void StreamCallbackFunc(gpuStream_t stream, } // namespace internal -class DnnWorkspaceHandle { - public: - explicit inline DnnWorkspaceHandle(Allocator* allocator) - : allocator_(allocator) {} - - inline void RunFunc(const std::function& cudnn_func, - size_t required_workspace_bytes) { - if (required_workspace_bytes > WorkspaceSize()) { - ReallocWorkspace(required_workspace_bytes); - } - VLOG(2) << "Cudnn workspace size at RunFunc: " - << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; - { - std::lock_guard guard(mtx_); - cudnn_func(allocation_ ? allocation_->ptr() : nullptr); - } - } - - /*! \brief Thread which call RunFuncSync() would release gpu memory after - * running the function. Currently this function is only used when cudnn - * exhaustive searching and callers have to guarantee that the input function - * is host blocking */ - inline void RunFuncSync(const std::function& cudnn_func, - size_t required_workspace_bytes) { - RunFunc(cudnn_func, required_workspace_bytes); - ResetWorkspace(); - } +void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; } - inline size_t WorkspaceSize() { - if (allocation_ == nullptr) { - return 0; - } - return allocation_->size(); - } - - void ResetWorkspace() { allocation_ = nullptr; } - - void ReallocWorkspace(size_t required_workspace_bytes) { - if (required_workspace_bytes <= WorkspaceSize()) return; - // reset allocation first before re-allocate to save memory - allocation_.reset(); - allocation_ = allocator_->Allocate(required_workspace_bytes); - } - - private: - Allocator::AllocationPtr allocation_{nullptr}; - Allocator* allocator_{nullptr}; - std::mutex mtx_; -}; +void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) { + if (required_workspace_bytes <= WorkspaceSize()) return; + // reset allocation first before re-allocate to save memory + allocation_.reset(); + allocation_ = allocator_->Allocate(required_workspace_bytes); +} struct GPUContext::Impl { void Init() { @@ -341,9 +301,15 @@ struct GPUContext::Impl { } } - DnnWorkspaceHandle* GetDnnWorkspace() { - PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr."); - return workspace_; + // TODO(wilber): The return type is a pointer, to be modified later. + // DnnWorkspaceHandle* GetDnnWorkspace() { + // PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr."); + // return workspace_; + // } + DnnWorkspaceHandle GetDnnWorkspace() { + PD_CHECK(allocator_ != nullptr, + "the device allocator for gpu context is nullptr."); + return DnnWorkspaceHandle(allocator_); } void InitStream() { @@ -797,7 +763,7 @@ Eigen::GpuDevice* GPUContext::eigen_device() const { return impl_->eigen_device(); } -DnnWorkspaceHandle* GPUContext::cudnn_workspace_handle() { +DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const { return impl_->GetDnnWorkspace(); } diff --git a/paddle/pten/backends/gpu/gpu_context.h b/paddle/pten/backends/gpu/gpu_context.h index 2a2be0e44b4f09506507ca97f17f2bfc59ca6130..d59773cfbff833e87dccfe462a7f4023c6404276 100644 --- a/paddle/pten/backends/gpu/gpu_context.h +++ b/paddle/pten/backends/gpu/gpu_context.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include "paddle/pten/backends/gpu/forwards.h" #include "paddle/pten/backends/gpu/gpu_decls.h" #include "paddle/pten/backends/gpu/gpu_helper.h" @@ -24,7 +25,53 @@ limitations under the License. */ namespace pten { -class DnnWorkspaceHandle; +class DnnWorkspaceHandle { + public: + explicit inline DnnWorkspaceHandle(Allocator* allocator) + : allocator_(allocator) { + mtx_.reset(new std::mutex()); + } + + inline void RunFunc(const std::function& cudnn_func, + size_t required_workspace_bytes) { + if (required_workspace_bytes > WorkspaceSize()) { + ReallocWorkspace(required_workspace_bytes); + } + { + std::lock_guard guard(*mtx_); + cudnn_func(allocation_ ? allocation_->ptr() : nullptr); + } + } + + /*! \brief Thread which call RunFuncSync() would release gpu memory after + * running the function. Currently this function is only used when cudnn + * exhaustive searching and callers have to guarantee that the input function + * is host blocking */ + inline void RunFuncSync(const std::function& cudnn_func, + size_t required_workspace_bytes) { + RunFunc(cudnn_func, required_workspace_bytes); + ResetWorkspace(); + } + + inline size_t WorkspaceSize() { + if (allocation_ == nullptr) { + return 0; + } + return allocation_->size(); + } + + void ResetWorkspace(); + + void ReallocWorkspace(size_t required_workspace_bytes); + + DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default; + DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete; + + private: + Allocator::AllocationPtr allocation_{nullptr}; + Allocator* allocator_{nullptr}; + std::unique_ptr mtx_; +}; class GPUContext : public DeviceContext { public: @@ -85,7 +132,8 @@ class GPUContext : public DeviceContext { * would be acquired to prevent other threads from accessing the * workspace. Once the handle is destructed, the lock would be released. */ - DnnWorkspaceHandle* cudnn_workspace_handle(); + // TODO(wilber): The return type is a pointer, to be modified later. + DnnWorkspaceHandle cudnn_workspace_handle() const; public: /*! \brief Call cublas function safely. */