Unverified commit 9aaae254, authored by Yuang Liu, committed by GitHub

Fix dev ctx with cuda graph (#44109)

Install the device context's CUDA graph allocator when graph capture begins and reset it to nullptr when capture ends, so that allocations made through the device context during capture are served by the CUDA graph memory pool.

Parent a7c98ddb
@@ -297,6 +297,10 @@ if(WITH_GPU)
     device_context_test
     SRCS device_context_test.cu
     DEPS device_context gpu_info)
+  nv_test(
+    device_context_test_cuda_graph
+    SRCS device_context_test_cuda_graph.cu
+    DEPS device_context gpu_info cuda_graph_with_memory_pool)
   nv_test(
     transform_test
     SRCS transform_test.cu
...
@@ -26,7 +26,9 @@ namespace platform {
 void BeginCUDAGraphCapture(platform::CUDAPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id) {
-  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* dev_ctx =
+      reinterpret_cast<platform::CUDADeviceContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
   // After PR(#43206), cudnn related initializations will change to lazy mode.
@@ -49,6 +51,9 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
   pool_id = CUDAGraph::SetMemoryPoolID(pool_id);
   memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
       pool_id);
+  dev_ctx->SetCUDAGraphAllocator(memory::allocation::AllocatorFacade::Instance()
+                                     .GetAllocator(place)
+                                     .get());
   if (old_value) {
     FLAGS_use_stream_safe_cuda_allocator = true;
   }
@@ -60,8 +65,11 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
   auto place = CUDAGraph::CapturingPlace();
-  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* dev_ctx =
+      reinterpret_cast<platform::CUDADeviceContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
+  dev_ctx->SetCUDAGraphAllocator(nullptr);
   return CUDAGraph::EndCapture();
 }
 #endif
...
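Taken together, the two hunks above make the capture entry points own the allocator bookkeeping: BeginCUDAGraphCapture installs the pool-backed allocator on the device context, and EndCUDAGraphCapture resets it to nullptr. A minimal caller-side sketch of the resulting lifecycle (the place and pool id values are illustrative, not from the PR):

    // Sketch only: allocator bookkeeping happens inside the capture calls.
    auto place = paddle::platform::CUDAPlace(0);
    paddle::platform::BeginCUDAGraphCapture(
        place, cudaStreamCaptureMode::cudaStreamCaptureModeThreadLocal,
        /*pool_id=*/0);
    // While this thread is capturing, DeviceContext allocations are routed
    // to the CUDA graph allocator (see the Alloc change further below).
    auto graph = paddle::platform::EndCUDAGraphCapture();
    // After EndCUDAGraphCapture(), IsCUDAGraphAllocatorValid() is false again.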
New file: device_context_test_cuda_graph.cu

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "cuda.h"          // NOLINT
#include "cuda_runtime.h"  // NOLINT

#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device_context.h"

TEST(Device, DeviceContextWithCUDAGraph) {
  using paddle::platform::CUDADeviceContext;
  using paddle::platform::CUDAPlace;
  using paddle::platform::DeviceContext;
  using paddle::platform::DeviceContextPool;
  using paddle::platform::Place;

  DeviceContextPool& pool = DeviceContextPool::Instance();
  Place place = CUDAPlace(0);
  auto* dev_ctx = pool.Get(place);

  paddle::platform::BeginCUDAGraphCapture(
      place, cudaStreamCaptureMode::cudaStreamCaptureModeThreadLocal, 0);
  ASSERT_EQ(dev_ctx->IsCUDAGraphAllocatorValid(), true);
  dev_ctx->GetCUDAGraphAllocator();
  paddle::platform::EndCUDAGraphCapture();
  ASSERT_EQ(dev_ctx->IsCUDAGraphAllocatorValid(), false);
}
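Assuming a standard WITH_GPU build (an assumption, not stated in the PR itself), this gtest is registered by the CMake hunk at the top and can be run in isolation with ctest -R device_context_test_cuda_graph from the build directory.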
@@ -14,6 +14,10 @@
 #include "paddle/phi/core/device_context.h"

+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
+#endif
+
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/selected_rows.h"
@@ -58,6 +62,26 @@ struct DeviceContext::Impl {
     pinned_allocator_ = allocator;
   }

+#ifdef PADDLE_WITH_CUDA
+  void SetCUDAGraphAllocator(const Allocator* allocator) {
+    // NOTE (Yuang): cuda graph allocator can be set to nullptr, so don't check
+    // validation of the allocator here
+    cuda_graph_allocator_ = allocator;
+  }
+
+  const Allocator& GetCUDAGraphAllocator() const {
+    PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
+                            phi::errors::InvalidArgument(
+                                "Required cuda_graph_allocator_ shall not be "
+                                "nullptr, but received nullptr."));
+    return *cuda_graph_allocator_;
+  }
+
+  bool IsCUDAGraphAllocatorValid() const {
+    return cuda_graph_allocator_ != nullptr;
+  }
+#endif
+
   const Allocator& GetAllocator() const {
     PADDLE_ENFORCE_NOT_NULL(
         device_allocator_,
@@ -111,6 +135,17 @@ struct DeviceContext::Impl {
     auto* allocator = tensor->numel() == 0
                           ? zero_allocator_
                           : (pinned ? pinned_allocator_ : device_allocator_);
+#ifdef PADDLE_WITH_CUDA
+    bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
+    if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) &&
+        paddle::platform::CUDAGraph::IsThisThreadCapturing()) {
+      PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
+                              phi::errors::InvalidArgument(
+                                  "Required cuda_graph_allocator_ shall not be "
+                                  "nullptr, but received nullptr."));
+      allocator = cuda_graph_allocator_;
+    }
+#endif
     return tensor->AllocateFrom(
         const_cast<Allocator*>(allocator), dtype, requested_size);
   }
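The routing rule added to Alloc above reduces to a small predicate: only non-zero-size, non-pinned allocations on a GPU place, made while the current thread is capturing, are rerouted. A hedged restatement with an illustrative helper name that is not part of the PR:

    // Illustrative only: when this returns true, Alloc must use
    // cuda_graph_allocator_ instead of device_allocator_.
    bool ShouldUseCUDAGraphAllocator(int64_t numel, bool pinned,
                                     bool is_gpu_place, bool capturing) {
      // Zero-size tensors keep zero_allocator_ and pinned tensors keep
      // pinned_allocator_; only real device allocations are rerouted.
      return numel != 0 && !pinned && is_gpu_place && capturing;
    }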
@@ -200,6 +235,9 @@ struct DeviceContext::Impl {
   const Allocator* host_allocator_{nullptr};
   const Allocator* zero_allocator_{nullptr};
   const Allocator* pinned_allocator_{nullptr};
+#ifdef PADDLE_WITH_CUDA
+  const Allocator* cuda_graph_allocator_{nullptr};
+#endif
   Generator* device_generator_{nullptr};
   Generator* host_generator_{nullptr};
 };
@@ -213,6 +251,11 @@ DeviceContext::DeviceContext(const DeviceContext& other) {
   impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
   impl_->SetHostGenerator(other.GetHostGenerator());
   impl_->SetGenerator(other.GetGenerator());
+#ifdef PADDLE_WITH_CUDA
+  if (other.IsCUDAGraphAllocatorValid()) {
+    impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator());
+  }
+#endif
 }

 DeviceContext::DeviceContext(DeviceContext&& other) {
@@ -239,6 +282,20 @@ const Allocator& DeviceContext::GetHostAllocator() const {
   return impl_->GetHostAllocator();
 }

+#ifdef PADDLE_WITH_CUDA
+void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) {
+  impl_->SetCUDAGraphAllocator(allocator);
+}
+
+const Allocator& DeviceContext::GetCUDAGraphAllocator() const {
+  return impl_->GetCUDAGraphAllocator();
+}
+
+bool DeviceContext::IsCUDAGraphAllocatorValid() const {
+  return impl_->IsCUDAGraphAllocatorValid();
+}
+#endif
+
 void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
   impl_->SetZeroAllocator(allocator);
 }
...
@@ -106,6 +106,33 @@ class PADDLE_API DeviceContext {
   const Allocator& GetPinnedAllocator() const;

+#ifdef PADDLE_WITH_CUDA
+  /**
+   * @brief Set the CUDA graph Allocator object.
+   *
+   * @param allocator
+   */
+  void SetCUDAGraphAllocator(const Allocator*);
+
+  /**
+   * @brief Get the const CUDA graph Allocator object.
+   *
+   * @return Allocator
+   */
+  const Allocator& GetCUDAGraphAllocator() const;
+
+  /**
+   * @brief Test whether the CUDA graph allocator is valid.
+   *
+   * This method should be called before calling GetCUDAGraphAllocator().
+   * Other units may call GetCUDAGraphAllocator() only when this method
+   * returns true.
+   *
+   * @return true if cuda_graph_allocator_ is valid, false otherwise
+   */
+  bool IsCUDAGraphAllocatorValid() const;
+#endif
+
   /**
    * @brief Allocate device memory for tensor.
    */
...
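The contract spelled out in the comment block above can be honored with a validity check before the throwing getter. A minimal caller-side sketch, assuming dev_ctx is a phi::DeviceContext* obtained from the pool:

    // Sketch: query validity first; GetCUDAGraphAllocator() enforces
    // non-null and raises InvalidArgument when nothing is installed.
    if (dev_ctx->IsCUDAGraphAllocatorValid()) {
      const phi::Allocator& alloc = dev_ctx->GetCUDAGraphAllocator();
      // ... allocate through alloc while capture is active ...
    }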
@@ -236,6 +236,16 @@ class TestCUDAGraph(unittest.TestCase):
         self.assertTrue(np.array_equal(actual_x, x.numpy()))
         self.assertTrue(np.array_equal(actual_y, y.numpy()))

+    def test_dev_ctx_alloc(self):
+        if not can_use_cuda_graph():
+            return
+
+        x = paddle.to_tensor([2], dtype='float32')
+        graph = CUDAGraph()
+        graph.capture_begin()
+        y = paddle.cast(x, dtype='float16')
+        graph.capture_end()
+
 if __name__ == "__main__":
     unittest.main()