提交 29f66c24 编写于 作者: Y Yu Yang

Polish code

上级 e25240c2
......@@ -167,7 +167,7 @@ class CudnnHolder {
if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len);
}
cudnn_func(workspace_->ptr());
cudnn_func(WorkspacePtr());
}
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
......@@ -181,6 +181,14 @@ class CudnnHolder {
}
}
void* WorkspacePtr() const {
if (workspace_ == nullptr) {
return nullptr;
} else {
return workspace_->ptr();
}
}
void ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= WorkspaceSize()) {
return;
......
......@@ -99,7 +99,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (size_t i = 0; i < tensor.dims().size(); ++i) {
for (int i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册