提交 c774bcbd 编写于 作者: Y Yu Yang

Merge device_context

test=develop
上级 057a682e
...@@ -160,29 +160,26 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -160,29 +160,26 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
}; };
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) { : workspace_(nullptr), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
} }
CudnnHolder::~CudnnHolder() { CudnnHolder::~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
if (workspace_ != nullptr) {
paddle::memory::Free(place_, workspace_);
}
} }
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) { void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= workspace_len_) { if (required_workspace_len <= WorkspaceSize()) {
return; return;
} }
if (workspace_ != nullptr) { if (workspace_ != nullptr) {
// Maybe someone is using the current workspace // Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
paddle::memory::Free(place_, workspace_); workspace_.reset();
} }
workspace_ = paddle::memory::Alloc(place_, required_workspace_len); workspace_ = paddle::memory::Alloc(place_, required_workspace_len,
workspace_len_ = required_workspace_len; paddle::memory::Allocator::kScratchpad);
} }
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/memory/malloc.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/cudnn.h"
...@@ -85,17 +85,32 @@ class CudnnHolder { ...@@ -85,17 +85,32 @@ class CudnnHolder {
template <typename Callback> template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) { void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > workspace_len_) { if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len); ReallocateWorkspace(required_workspace_len);
} }
cudnn_func(workspace_); cudnn_func(WorkspacePtr());
}
inline void* WorkspacePtr() {
if (workspace_) {
return workspace_->ptr();
} else {
return nullptr;
}
}
inline size_t WorkspaceSize() {
if (workspace_) {
return workspace_->size();
} else {
return 0;
}
} }
std::mutex& Mutex() { return mtx_; } std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
void* workspace_; std::unique_ptr<memory::Allocation> workspace_;
size_t workspace_len_;
const cudaStream_t* stream_; // not owned; const cudaStream_t* stream_; // not owned;
const CUDAPlace place_; const CUDAPlace place_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册