提交 15cc9128 编写于 作者: F fengjiayi

fix compile error

上级 407ff0bd
...@@ -145,9 +145,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -145,9 +145,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
class CudnnHolder { class CudnnHolder {
public: public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place) CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) { : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
} }
cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; } cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }
...@@ -157,14 +157,14 @@ class CudnnHolder { ...@@ -157,14 +157,14 @@ class CudnnHolder {
void* new_workspace = paddle::memory::Alloc(place_, required_len); void* new_workspace = paddle::memory::Alloc(place_, required_len);
if (workspace_ != nullptr) { if (workspace_ != nullptr) {
// Maybe someone is using the current workspace // Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
paddle::memory::Free(place_, workspace_); paddle::memory::Free(place_, workspace_);
} }
workspace_ = new_workspace; workspace_ = new_workspace;
workspace_len_ = required_len; workspace_len_ = required_len;
} }
return workspace_ return workspace_;
} }
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
...@@ -231,7 +231,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { ...@@ -231,7 +231,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder_->get_cudnn_handle(); return cudnn_holder_->get_cudnn_handle();
} }
void* cudnn_workspace(size_t required_len) const { void* CUDADeviceContext::cudnn_workspace(size_t required_len) const {
return cudnn_holder_->get_workspace(required_len); return cudnn_holder_->get_workspace(required_len);
} }
......
...@@ -69,7 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> { ...@@ -69,7 +69,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CUDNNHolder; class CudnnHolder;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册