提交 29f66c24 编写于 作者: Y Yu Yang

Polish code

上级 e25240c2
...@@ -167,7 +167,7 @@ class CudnnHolder { ...@@ -167,7 +167,7 @@ class CudnnHolder {
if (required_workspace_len > WorkspaceSize()) { if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len); ReallocateWorkspace(required_workspace_len);
} }
cudnn_func(workspace_->ptr()); cudnn_func(WorkspacePtr());
} }
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
...@@ -181,6 +181,14 @@ class CudnnHolder { ...@@ -181,6 +181,14 @@ class CudnnHolder {
} }
} }
void* WorkspacePtr() const {
if (workspace_ == nullptr) {
return nullptr;
} else {
return workspace_->ptr();
}
}
void ReallocateWorkspace(size_t required_workspace_len) { void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= WorkspaceSize()) { if (required_workspace_len <= WorkspaceSize()) {
return; return;
......
...@@ -99,7 +99,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -99,7 +99,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
py_buffer->shape = reinterpret_cast<Py_ssize_t *>( py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size())); malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (size_t i = 0; i < tensor.dims().size(); ++i) { for (int i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i]; py_buffer->shape[i] = tensor.dims()[i];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册