diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 8cf53b9739992760240284a172f108d6753a5608..ce24f5a4d9c1868d2c35b5d5c56500ad1175ba79 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -17,8 +17,6 @@ cc_library(memory memcpy) if (WITH_GPU) - add_dependencies(malloc cuda_device_context_allocator_pool) - target_link_libraries(malloc cuda_device_context_allocator_pool) nv_test(malloc_test SRCS malloc_test.cu DEPS device_context malloc) diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index f00dda0b54843e7e2e50b151e91b8ee0664c3618..ffae6e648080ba32fafd38440e8ff8590437669a 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -14,12 +14,6 @@ endif() if (WITH_GPU) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) - nv_library(cuda_device_context_allocation SRCS cuda_device_context_allocation.cc - DEPS allocator enforce place ${MKLDNN_CTX_DEPS}) - nv_library(cuda_device_context_allocator SRCS cuda_device_context_allocator.cc - DEPS allocator enforce place cuda_device_context_allocation ${MKLDNN_CTX_DEPS}) - nv_library(cuda_device_context_allocator_pool SRCS cuda_device_context_allocator_pool.cc - DEPS allocator enforce place cuda_device_context_allocation cuda_device_context_allocator ${MKLDNN_CTX_DEPS}) endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc deleted file mode 100644 index e361f71f4f75b14fa46e4bd5940ab100e7110cb2..0000000000000000000000000000000000000000 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocation.cc +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h" -#include -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace memory { -namespace allocation { - -CUDADeviceContextAllocation::CUDADeviceContextAllocation( - AllocationPtr allocation) - : Allocation(allocation->ptr(), allocation->size(), allocation->place()), - underlying_allocation_(std::move(allocation)) {} - -CUDADeviceContextAllocation::~CUDADeviceContextAllocation() { - PADDLE_ENFORCE_NOT_NULL( - dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation"); - auto *p_allocation = underlying_allocation_.release(); - VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at " - << p_allocation; - dev_ctx_->AddStreamCallback([p_allocation] { - VLOG(4) << "Delete CUDADeviceContextAllocation at " << p_allocation; - AllocationDeleter()(p_allocation); - }); -} - -void CUDADeviceContextAllocation::SetCUDADeviceContext( - const platform::CUDADeviceContext *dev_ctx) { - dev_ctx_ = dev_ctx; -} - -} // namespace allocation -} // namespace memory -} // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocation.h b/paddle/fluid/memory/allocation/cuda_device_context_allocation.h deleted file mode 100644 index 02011f88c1d9d80b24c7bd1c28747a85e4738711..0000000000000000000000000000000000000000 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocation.h +++ 
/dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace memory { -namespace allocation { - -/** - * CUDADeviceContextAllocation is a wrapper of the underbeneath allocation. - * CUDADeviceContextAllocation adds a CUDA stream callback for the underbeneath - * allocation so that CUDADeviceContextAllocation can be used in a CUDA stream - * which deletes allocation in the callback. - */ -class CUDADeviceContextAllocation : public Allocation { - public: - explicit CUDADeviceContextAllocation(AllocationPtr allocation); - ~CUDADeviceContextAllocation(); - void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx); - - private: - AllocationPtr underlying_allocation_; - const platform::CUDADeviceContext *dev_ctx_{nullptr}; -}; - -} // namespace allocation -} // namespace memory -} // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc deleted file mode 100644 index bc9adc5caa261485fe383dc0ebd33f92beaebdff..0000000000000000000000000000000000000000 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" - -#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h" -#include "paddle/fluid/platform/cuda_device_guard.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace memory { -namespace allocation { - -CUDADeviceContextAllocator::CUDADeviceContextAllocator( - const platform::CUDAPlace place, cudaStream_t default_stream) - : place_(place), default_stream_(default_stream) { - platform::CUDADeviceGuard guard(place_.device); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventCreate(&event_, cudaEventDisableTiming), - "Create event failed in CUDADeviceContextAllocator"); -} - -CUDADeviceContextAllocator::~CUDADeviceContextAllocator() { - if (event_) { - platform::CUDADeviceGuard guard(place_.device); - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventDestroy(event_), - "Destory event failed in CUDADeviceContextAllocator destroctor"); - } -} - -Allocation *CUDADeviceContextAllocator::AllocateImpl(size_t size) { - PADDLE_ENFORCE_NOT_NULL( - default_stream_, - "Didn't set default stream for CUDADeviceContextAllocator"); - platform::CUDADeviceGuard guard(place_.device); - auto allocation = - new CUDADeviceContextAllocation(memory::Alloc(place_, size)); - // Wait for the event on stream - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventRecord(event_, default_stream_), - "Failed to record event in CUDADeviceContextAllocator"); - 
PADDLE_ENFORCE_CUDA_SUCCESS( - cudaStreamWaitEvent(default_stream_, event_, 0), - "Failed to wait event in CUDADeviceContextAllocator"); - return allocation; -} - -void CUDADeviceContextAllocator::FreeImpl(Allocation *allocation) { - delete allocation; -} - -} // namespace allocation -} // namespace memory -} // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h index 34bd1176db9cd1bfa57fd5dee401705539f974ad..1f8ad370bf2f7ed780e45f5775e8e599bdfbed71 100644 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h @@ -15,15 +15,58 @@ #pragma once #include +#include +#include +#include +#include #include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" namespace paddle { + +namespace platform { +class CUDADeviceContext; +} // namespace platform + namespace memory { namespace allocation { +/** + * CUDADeviceContextAllocation is a wrapper of the underlying allocation. + * CUDADeviceContextAllocation adds a CUDA stream callback for the underlying + * allocation so that CUDADeviceContextAllocation can be used in a CUDA stream + * which deletes allocation in the callback. 
+ */ +class CUDADeviceContextAllocation : public Allocation { + public: + explicit CUDADeviceContextAllocation(AllocationPtr allocation) + : Allocation(allocation->ptr(), allocation->size(), allocation->place()), + underlying_allocation_(std::move(allocation)) {} + + ~CUDADeviceContextAllocation() { + PADDLE_ENFORCE_NOT_NULL( + dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation"); + auto *p_allocation = underlying_allocation_.release(); + VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at " + << p_allocation; + dev_ctx_->AddStreamCallback([p_allocation] { + VLOG(4) << "Delete CUDADeviceContextAllocation at " << p_allocation; + AllocationDeleter()(p_allocation); + }); + } + + void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx) { + dev_ctx_ = dev_ctx; + } + + private: + AllocationPtr underlying_allocation_; + const platform::CUDADeviceContext *dev_ctx_{nullptr}; +}; + /** * CUDADeviceContextAllocator will allocate a CUDADeviceContextAllocation * after waiting for a self-created event on the default stream. 
It does so to @@ -33,12 +76,42 @@ namespace allocation { class CUDADeviceContextAllocator : public Allocator { public: explicit CUDADeviceContextAllocator(platform::CUDAPlace place, - cudaStream_t default_stream); - ~CUDADeviceContextAllocator(); + cudaStream_t default_stream) + : place_(place), default_stream_(default_stream) { + platform::CUDADeviceGuard guard(place_.device); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventCreate(&event_, cudaEventDisableTiming), + "Create event failed in CUDADeviceContextAllocator"); + } + + ~CUDADeviceContextAllocator() { + if (event_) { + platform::CUDADeviceGuard guard(place_.device); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventDestroy(event_), + "Destroy event failed in CUDADeviceContextAllocator destructor"); + } + } protected: - Allocation *AllocateImpl(size_t size) override; - void FreeImpl(Allocation *allocation) override; + Allocation *AllocateImpl(size_t size) override { + PADDLE_ENFORCE_NOT_NULL( + default_stream_, + "Didn't set default stream for CUDADeviceContextAllocator"); + platform::CUDADeviceGuard guard(place_.device); + auto allocation = + new CUDADeviceContextAllocation(memory::Alloc(place_, size)); + // Wait for the event on stream + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventRecord(event_, default_stream_), + "Failed to record event in CUDADeviceContextAllocator"); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamWaitEvent(default_stream_, event_, 0), + "Failed to wait event in CUDADeviceContextAllocator"); + return allocation; + } + + void FreeImpl(Allocation *allocation) override { delete allocation; } private: platform::CUDAPlace place_; @@ -46,6 +119,49 @@ class CUDADeviceContextAllocator : public Allocator { cudaStream_t default_stream_{nullptr}; }; +/** + * CUDADeviceContextAllocatorPool is a singleton that stores mapping from + * CUDAPlace(s) to std::shared_ptr. 
When a + * CUDADeviceContext's compute stream isn't default stream, it can call this + * class to allocate GPU memory which will be released by a callback after + * stream execution. + */ +class CUDADeviceContextAllocatorPool { + public: + static CUDADeviceContextAllocatorPool &Instance() { + static CUDADeviceContextAllocatorPool pool; + return pool; + } + + AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size) { + auto iter = + allocators_.find(boost::get(dev_ctx.GetPlace())); + PADDLE_ENFORCE_EQ(iter != allocators_.end(), true, + "CUDADeviceContextAllocatorPool initialization error"); + auto &allocator = iter->second; + AllocationPtr allocation = allocator->Allocate(size); + static_cast(allocation.get()) + ->SetCUDADeviceContext(&dev_ctx); + return allocation; + } + + private: + CUDADeviceContextAllocatorPool() { + std::vector devices = platform::GetSelectedDevices(); + for (int i : devices) { + auto place = platform::CUDAPlace(i); + auto compute_stream = + platform::DeviceContextPool::Instance().GetByPlace(place)->stream(); + auto allocator = std::shared_ptr( + new CUDADeviceContextAllocator(place, compute_stream)); + allocators_.insert(make_pair(place, allocator)); + } + } + + std::map> + allocators_; +}; + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc deleted file mode 100644 index e0b6825944cac826e9e0571c633e8a98250c7570..0000000000000000000000000000000000000000 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h" - -#include -#include -#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h" -#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace memory { -namespace allocation { - -CUDADeviceContextAllocatorPool &CUDADeviceContextAllocatorPool::Instance() { - static CUDADeviceContextAllocatorPool pool; - return pool; -} - -AllocationPtr CUDADeviceContextAllocatorPool::Alloc( - const platform::CUDADeviceContext &dev_ctx, size_t size) { - auto iter = - allocators_.find(boost::get(dev_ctx.GetPlace())); - PADDLE_ENFORCE_EQ(iter != allocators_.end(), true, - "CUDADeviceContextAllocatorPool initialization error"); - auto &allocator = iter->second; - AllocationPtr allocation = allocator->Allocate(size); - static_cast(allocation.get()) - ->SetCUDADeviceContext(&dev_ctx); - return allocation; -} - -CUDADeviceContextAllocatorPool::CUDADeviceContextAllocatorPool() { - std::vector devices = platform::GetSelectedDevices(); - for (int i : devices) { - auto place = platform::CUDAPlace(i); - auto compute_stream = - platform::DeviceContextPool::Instance().GetByPlace(place)->stream(); - auto allocator = std::shared_ptr( - new CUDADeviceContextAllocator(place, compute_stream)); - allocators_.insert(make_pair(place, allocator)); - } -} - -} // namespace allocation -} // namespace memory -} // namespace paddle diff --git 
a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h deleted file mode 100644 index b423f226d94492c4ce0d8c8752b0cca2b1745bb3..0000000000000000000000000000000000000000 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace memory { -namespace allocation { - -/** - * CUDADeviceContextAllocatorPool is a singletion stores mapping from - * CUDAPlace(s) to std::shared_ptr. When a - * CUDADeviceContext's compute stream isn't default stream, it can call this - * class to allocate GPU memory which will be released by a callback after - * stream execution. 
- */ -class CUDADeviceContextAllocatorPool { - public: - static CUDADeviceContextAllocatorPool &Instance(); - - AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size); - - private: - CUDADeviceContextAllocatorPool(); - std::map> - allocators_; -}; - -} // namespace allocation -} // namespace memory -} // namespace paddle diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index f1a75f2add384910b706ac69c2d001ad1e659359..e01f030585a8330a2e9bcc2bc2a662f00f5cde1c 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -17,10 +17,6 @@ limitations under the License. */ #include #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h" -#endif -#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -35,26 +31,5 @@ AllocationPtr Alloc(const platform::Place &place, size_t size) { return allocation::AllocatorFacade::Instance().Alloc(place, size); } -AllocationPtr Alloc(const platform::DeviceContext &dev_ctx, size_t size) { - auto place = dev_ctx.GetPlace(); -#ifdef PADDLE_WITH_CUDA - if (size == 0 || !platform::is_gpu_place(place)) { - return Alloc(place, size); - } - auto *default_dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - auto &desired_dev_ctx = - static_cast(dev_ctx); - if (default_dev_ctx->stream() == desired_dev_ctx.stream()) { - return Alloc(place, size); - } else { - return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc( - desired_dev_ctx, size); - } -#else - return Alloc(place, size); -#endif -} - } // namespace memory } // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 3e0f2490ff469a28e6615bddeab5a7282f40c610..a84f521f589ab680513852fdb83e593ba4946fe5 100644 
--- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -70,7 +70,7 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps}) diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index ce0d73f520a711d1cd7d77358425a6bc2ab3de60..3166593365404e98fad0e91a7d7b5cd7176cd9ed 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -18,11 +18,39 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/framework/rw_lock.h" +#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" #include "paddle/fluid/platform/cuda_device_guard.h" #endif #include "glog/logging.h" +namespace paddle { +namespace memory { + +AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) { + auto place = dev_ctx.GetPlace(); +#ifdef PADDLE_WITH_CUDA + if (size == 0 || !platform::is_gpu_place(place)) { + return Alloc(place, size); + } + auto* default_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto& desired_dev_ctx = + static_cast(dev_ctx); + if (default_dev_ctx->stream() == desired_dev_ctx.stream()) { + return Alloc(place, size); + } else { + return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc( + desired_dev_ctx, size); + } +#else + return Alloc(place, size); +#endif +} + +} // namespace memory +} // namespace paddle + namespace paddle { namespace platform { @@ -174,6 +202,15 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable 
std::unordered_map allocations_; }; +void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) { + if (required_workspace_bytes <= WorkspaceSize()) { + return; + } + // reset allocation first before re-allocate to save memory + allocation_.reset(); + allocation_ = memory::Alloc(device_context_, required_workspace_bytes); +} + CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { CUDADeviceGuard guard(place_.device); compute_capability_ = GetCUDAComputeCapability(place_.device); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index cbb700fb35a648702670d31db5d339397b2c9f86..3504f62b7bdaa523deb2ae2074cf0d22cfe93851 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -220,14 +220,7 @@ class CudnnWorkspaceHandle { ResetWorkspace(); } - inline void ReallocWorkspace(size_t required_workspace_bytes) { - if (required_workspace_bytes <= WorkspaceSize()) { - return; - } - // reset allocation first before re-allocate to save memory - allocation_.reset(); - allocation_ = memory::Alloc(device_context_, required_workspace_bytes); - } + void ReallocWorkspace(size_t required_workspace_bytes); inline void ResetWorkspace() { allocation_ = nullptr; }