未验证 提交 eb97f4f0 编写于 作者: W Wilber 提交者: GitHub

[Inference] Update switch stream logical. (#53589)

上级 05d3fc81
......@@ -94,6 +94,61 @@
#endif
namespace paddle {
namespace {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Installs the per-stream resources held by `gpu_resource` on `gpu_context`.
//
// Sets the allocators (stream-bound device, pinned, host, zero-size device and
// host), the CUDA/CPU generators, the stream, the handle creators (BLAS, BLAS
// tensor-core, BLAS TF32, DNN, solver, sparse), the Eigen device and the
// device properties exposed by the resource.
//
// Both pointers are dereferenced unconditionally, so they must be non-null.
//
// gpu_context   context being configured.
// gpu_resource  resource whose stream, handles and properties are installed.
// place_        place used for the allocator and generator lookups.
void UpdatePrivateDeviceContext(InferGPUContext *gpu_context,
                                GPUContextResource *gpu_resource,
                                Place place_) {
  auto &allocator_facade = memory::allocation::AllocatorFacade::Instance();
  const auto stream = gpu_resource->GetStream();
  // Looked up once and reused for both SetAllocator and the log line below.
  auto *stream_allocator = allocator_facade.GetAllocator(place_, stream).get();

  // Allocators.
  gpu_context->SetAllocator(stream_allocator);
  gpu_context->SetPinnedAllocator(
      allocator_facade.GetAllocator(paddle::platform::CUDAPinnedPlace())
          .get());
  gpu_context->SetHostAllocator(
      allocator_facade.GetAllocator(platform::CPUPlace()).get());
  gpu_context->SetZeroAllocator(
      allocator_facade.GetZeroAllocator(place_).get());
  gpu_context->SetHostZeroAllocator(
      allocator_facade.GetZeroAllocator(platform::CPUPlace()).get());

  // Random generators.
  gpu_context->SetGenerator(
      phi::DefaultCUDAGenerator(place_.GetDeviceId()).get());
  gpu_context->SetHostGenerator(phi::DefaultCPUGenerator().get());

  // Stream and library handles.
  gpu_context->SetStream(stream);
  gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator());
  gpu_context->SetBlasTensorCoreHandle(
      gpu_resource->GetBlasTensorCoreHandleCreator());
  gpu_context->SetBlasTF32Handle(
      gpu_resource->GetBlasTF32TensorCoreHandleCreator());
  gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator());
  gpu_context->SetSolverHandle(gpu_resource->GetSolverDnHandleCreator());
  gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator());
  gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDevice());

  // Device properties reported by the resource.
  gpu_context->SetComputeCapability(gpu_resource->GetGpuComputeCapability());
  gpu_context->SetMaxThreadsPerBlock(gpu_resource->GetGpuMaxThreadsPerBlock());
  gpu_context->SetMaxThreadsPerMultiProcessor(
      gpu_resource->GetGpuMaxThreadsPerMp());
  gpu_context->SetMaxGridDimSize(gpu_resource->GetGpuMaxGridDimSize());
  gpu_context->SetMultiProcessors(gpu_resource->GetGPUMultiProcessors());
  gpu_context->SetDriverVersion(gpu_resource->GetGpuDriverVersion());
  gpu_context->SetRuntimeVersion(gpu_resource->GetGpuRuntimeVersion());

  VLOG(1) << "thread id is " << std::this_thread::get_id() << ", stream id is "
          << reinterpret_cast<void *>(stream) << ", allocator ptr is "
          << reinterpret_cast<void *>(stream_allocator);
}
#endif
} // namespace
using inference::Singleton;
#ifdef PADDLE_WITH_TENSORRT
......@@ -451,60 +506,7 @@ void AnalysisPredictor::InitDeviceContexts() {
auto *gpu_resource =
ResourceManager::Instance().GetGPUResource(predictor_stream_);
auto *gpu_context = new InferGPUContext(place_);
gpu_context->SetAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place_, gpu_resource->GetStream())
.get());
gpu_context->SetPinnedAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
gpu_context->SetHostAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CPUPlace())
.get());
gpu_context->SetZeroAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place_)
.get());
gpu_context->SetHostZeroAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CPUPlace())
.get());
gpu_context->SetGenerator(
phi::DefaultCUDAGenerator(place_.GetDeviceId()).get());
gpu_context->SetHostGenerator(phi::DefaultCPUGenerator().get());
gpu_context->SetStream(gpu_resource->GetStream());
gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator());
gpu_context->SetBlasTensorCoreHandle(
gpu_resource->GetBlasTensorCoreHandleCreator());
gpu_context->SetBlasTF32Handle(
gpu_resource->GetBlasTF32TensorCoreHandleCreator());
gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator());
gpu_context->SetSolverHandle(
gpu_resource->GetSolverDnHandleCreator());
gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator());
gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDeviceCreator());
gpu_context->SetComputeCapability(
gpu_resource->GetGpuComputeCapability());
gpu_context->SetMaxThreadsPerBlock(
gpu_resource->GetGpuMaxThreadsPerBlock());
gpu_context->SetMaxThreadsPerMultiProcessor(
gpu_resource->GetGpuMaxThreadsPerMp());
gpu_context->SetMaxGridDimSize(gpu_resource->GetGpuMaxGridDimSize());
gpu_context->SetMultiProcessors(
gpu_resource->GetGPUMultiProcessors());
gpu_context->SetDriverVersion(gpu_resource->GetGpuDriverVersion());
gpu_context->SetRuntimeVersion(gpu_resource->GetGpuRuntimeVersion());
VLOG(1) << "thread id is " << std::this_thread::get_id()
<< ", stream id is "
<< reinterpret_cast<void *>(gpu_resource->GetStream())
<< ", allotor ptr is "
<< reinterpret_cast<void *>(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place_, gpu_resource->GetStream())
.get());
UpdatePrivateDeviceContext(gpu_context, gpu_resource, place_);
return std::unique_ptr<phi::DeviceContext>(gpu_context);
}));
}
......@@ -2083,17 +2085,27 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
#else
cudaStreamSynchronize(static_cast<gpuStream_t>(predictor_stream_));
#endif
ResourceManager::Instance().GpuResourceReBindStream(predictor_stream_,
ResourceManager::Instance().GpuResourceSwitchStream(predictor_stream_,
stream);
predictor_stream_ = stream;
auto *dev_ctxs = reinterpret_cast<const std::map<
phi::Place,
std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
this->GetDeviceContexts());
auto *dev_ctx =
static_cast<InferGPUContext *>(dev_ctxs->at(place_).get().get());
dev_ctx->SetStream(stream);
auto *dev_ctxs = const_cast<
std::map<phi::Place,
std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
reinterpret_cast<const std::map<
phi::Place,
std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
this->GetDeviceContexts()));
dev_ctxs->erase(place_);
dev_ctxs->emplace(
place_, std::async(std::launch::deferred, [=] {
auto *gpu_resource =
ResourceManager::Instance().GetGPUResource(predictor_stream_);
auto *gpu_context = new InferGPUContext(place_);
UpdatePrivateDeviceContext(gpu_context, gpu_resource, place_);
return std::unique_ptr<phi::DeviceContext>(gpu_context);
}));
}
return ZeroCopyRun();
......
......@@ -154,6 +154,7 @@ void GPUContextResource::InitGPUResource(void* stream) {
}
InitGpuProperties();
InitGpuEigenDevice();
}
void GPUContextResource::DestroyGPUResource() {
......@@ -361,90 +362,6 @@ std::array<int, 3> GPUContextResource::GetGpuMaxGridDimSize() const {
return max_grid_dim_size_;
}
// Makes `stream` the stream tracked by this resource and clears the
// owned_stream_ flag.
// NOTE(review): presumably the flag marks the stream as externally owned so
// this object does not destroy it - confirm against DestroyGPUResource.
void GPUContextResource::ReBindStream(gpuStream_t stream) {
  stream_ = stream;
  owned_stream_ = false;
}
// Attaches the DNN handle (MIOpen under HIP, cuDNN otherwise) to `stream`.
// Does nothing when no DNN handle is set.
void GPUContextResource::ReBindDnnHandle(gpuStream_t stream) const {
  if (!dnn_handle_) return;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::miopenSetStream(dnn_handle_, stream));
#else
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cudnnSetStream(dnn_handle_, stream));
#endif
}
// Attaches the BLAS handle (rocBLAS under HIP, cuBLAS otherwise) to `stream`.
// Does nothing when no BLAS handle is set.
void GPUContextResource::ReBindBlasHandle(gpuStream_t stream) const {
  if (!blas_handle_) return;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::rocblas_set_stream(blas_handle_, stream));
#else
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cublasSetStream(blas_handle_, stream));
#endif
}
// Attaches the tensor-core BLAS handle to `stream` (rocBLAS under HIP, cuBLAS
// otherwise). Does nothing when that handle is not set.
void GPUContextResource::ReBindBlasTensorCoreHandle(gpuStream_t stream) const {
  if (!blas_tensor_core_handle_) return;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      phi::dynload::rocblas_set_stream(blas_tensor_core_handle_, stream));
#else
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cublasSetStream(blas_tensor_core_handle_, stream));
#endif
}
// Attaches the TF32 tensor-core BLAS handle to `stream` (rocBLAS under HIP,
// cuBLAS otherwise). Does nothing when that handle is not set.
void GPUContextResource::ReBindBlasTF32Handle(gpuStream_t stream) const {
  if (!blas_tf32_tensor_core_handle_) return;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_set_stream(
      blas_tf32_tensor_core_handle_, stream));
#else
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cublasSetStream(blas_tf32_tensor_core_handle_, stream));
#endif
}
// Attaches the cuSOLVER dense handle to `stream`. Does nothing when the handle
// is not set; under HIP the call is compiled out entirely.
void GPUContextResource::ReBindSolverDnHandle(gpuStream_t stream) const {
  if (!solver_handle_) return;
#ifndef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cusolverDnSetStream(solver_handle_, stream));
#endif
}
// Attaches the cuSPARSE handle to `stream`. Does nothing when the handle is
// not set. The call is only compiled for CUDA builds with
// CUDA_VERSION >= 11000.
void GPUContextResource::ReBindSparseHandle(gpuStream_t stream) const {
  if (!sparse_handle_) return;
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11000
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cusparseSetStream(sparse_handle_, stream));
#endif
#endif
}
// Reinitializes the Eigen stream object with `stream`, an allocator and
// `place`. Does nothing when eigen_stream_ is not set.
// NOTE(review): the allocator is looked up with the member place_, while the
// `place` argument is what gets forwarded to Reinitialize - confirm that the
// two are meant to differ.
void GPUContextResource::ReBindEigenDevice(gpuStream_t stream,
                                           GPUPlace place) const {
  if (!eigen_stream_) return;
  auto& facade = paddle::memory::allocation::AllocatorFacade::Instance();
  auto* eigen_allocator = facade.GetAllocator(place_).get();
  eigen_stream_->Reinitialize(stream, eigen_allocator, place);
}
#endif
void ResourceManager::InitCPUResource() {
......@@ -486,24 +403,16 @@ void ResourceManager::DestroyGPUResource(void* stream) {
}
void ResourceManager::Decrease(void* stream) {
PADDLE_ENFORCE_EQ(ref_count_.count(stream),
true,
platform::errors::InvalidArgument(
"The stream[%p] not found in ref_count.", stream));
if (ref_count_.count(stream) == 0) return;
--ref_count_[stream];
if (ref_count_[stream] == 0) {
ref_count_.erase(stream);
gpu_resources_.erase(stream);
if (gpu_resources_.count(stream) > 0) gpu_resources_.erase(stream);
}
}
// Increments the reference count recorded for `stream`.
// The enforce below reports InvalidArgument when `stream` has no entry in
// ref_count_, so only already-registered streams can be incremented here.
void ResourceManager::Increase(void* stream) {
PADDLE_ENFORCE_EQ(ref_count_.count(stream),
true,
platform::errors::InvalidArgument(
"The stream[%p] not found in ref_count.", stream));
++ref_count_[stream];
}
// Increments the reference count recorded for `stream`; no presence check is
// performed first.
// NOTE(review): if ref_count_ is a standard map, operator[] creates a zero
// entry for an unseen stream before the increment - confirm that is intended.
void ResourceManager::Increase(void* stream) { ++ref_count_[stream]; }
GPUContextResource* ResourceManager::GetGPUResource(void* stream) const {
PADDLE_ENFORCE_EQ(gpu_resources_.count(stream),
......@@ -513,33 +422,29 @@ GPUContextResource* ResourceManager::GetGPUResource(void* stream) const {
return gpu_resources_.at(stream).get();
}
void ResourceManager::GpuResourceReBindStream(void* old_stream,
void ResourceManager::GpuResourceSwitchStream(void* old_stream,
void* new_stream) {
// NOTE: add lock to support stream rebind in multi-thread
std::lock_guard<std::mutex> lock_gurad(gpu_mutex_);
if (old_stream == new_stream) return;
PADDLE_ENFORCE_EQ(
gpu_resources_.count(old_stream),
true,
platform::errors::InvalidArgument(
"The stream[%p] not found in gpu_resources.", old_stream));
auto gpu_resource = std::move(gpu_resources_.at(old_stream));
DestroyGPUResource(old_stream);
PADDLE_ENFORCE_EQ(
ref_count_.count(old_stream),
0,
platform::errors::Fatal("gpu resources rebind stream failed."));
gpu_resource->ReBindStream(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindDnnHandle(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindBlasHandle(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindBlasTensorCoreHandle(
static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindBlasTF32Handle(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindSolverDnHandle(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindSparseHandle(static_cast<gpuStream_t>(new_stream));
gpu_resource->ReBindEigenDevice(static_cast<gpuStream_t>(new_stream),
gpu_resource->Place());
ref_count_[new_stream]++;
gpu_resources_.emplace(new_stream, std::move(gpu_resource));
// NOTE: a stream may be used by multiple predictors; skip resource creation
// if a resource for new_stream already exists.
bool new_stream_existed = gpu_resources_.count(new_stream) > 0;
if (!new_stream_existed) {
auto place = gpu_resources_.at(old_stream)->Place();
std::unique_ptr<GPUContextResource> resource{
new GPUContextResource(place, new_stream)};
gpu_resources_.emplace(new_stream, std::move(resource));
}
Decrease(old_stream);
Increase(new_stream);
}
int ResourceManager::RefCount(void* stream) const {
......
......@@ -82,16 +82,6 @@ class GPUContextResource {
int GetGpuMaxThreadsPerBlock() const;
std::array<int, 3> GetGpuMaxGridDimSize() const;
// If stream changes, we need to rebind all handle to new stream.
void ReBindStream(gpuStream_t stream);
void ReBindDnnHandle(gpuStream_t stream) const;
void ReBindBlasHandle(gpuStream_t stream) const;
void ReBindBlasTensorCoreHandle(gpuStream_t stream) const;
void ReBindBlasTF32Handle(gpuStream_t stream) const;
void ReBindSolverDnHandle(gpuStream_t stream) const;
void ReBindSparseHandle(gpuStream_t stream) const;
void ReBindEigenDevice(gpuStream_t stream, GPUPlace place) const;
private:
void InitGPUResource(void* stream);
void DestroyGPUResource();
......@@ -186,7 +176,7 @@ class ResourceManager {
void DestroyGPUResource(void* stream);
GPUContextResource* GetGPUResource(void* stream) const;
int RefCount(void* stream) const;
void GpuResourceReBindStream(void* old_stream, void* new_stream);
void GpuResourceSwitchStream(void* old_stream, void* new_stream);
private:
void Decrease(void* stream);
......
......@@ -1028,6 +1028,18 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST)
target_link_libraries(test_analyzer_capi_exp_xpu paddle_inference_c)
endif()
# TODO(inference): this test hits an SEH error on Windows; it needs to be fixed.
if(NOT WIN32)
inference_analysis_test(
trt_rebind_stream_test
SRCS
trt_rebind_stream_test.cc
EXTRA_DEPS
paddle_inference_shared
ARGS
--infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
endif()
set(TRT_MODEL_QUANT_RESNET_DIR
"${INFERENCE_DEMO_INSTALL_DIR}/small_quant_model")
if(NOT EXISTS ${INFERENCE_DEMO_INSTALL_DIR}/small_quant_model.tgz)
......@@ -1378,6 +1390,10 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST)
endif()
if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 240)
if(NOT WIN32)
set_tests_properties(trt_rebind_stream_test
PROPERTIES TIMEOUT 360 LABELS "RUN_TYPE=EXCLUSIVE")
endif()
if(WITH_MKLDNN)
set_tests_properties(test_analyzer_bfloat16_resnet50 PROPERTIES TIMEOUT
120)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>

#include <thread>
#include <vector>

#include "gflags/gflags.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "test/cpp/inference/api/tester_helper.h"
namespace paddle {
namespace inference {
// TODO(inference): These cases fail on Windows with an SEH error; this needs
// to be fixed.
// Runs one TensorRT-enabled predictor on its configured stream, then switches
// it to two further external streams in turn, expecting every run to succeed.
TEST(ReBindStream_single, use_gpu) {
  std::string model_dir = FLAGS_infer_model + "/mobilenet";
  AnalysisConfig config;
  config.EnableUseGpu(100, 0);
  config.SetModel(model_dir);
  config.EnableTensorRtEngine();

  // Stop immediately if a stream cannot be created instead of handing an
  // uninitialized handle to the predictor.
  cudaStream_t stream1 = nullptr;
  cudaStream_t stream2 = nullptr;
  cudaStream_t stream3 = nullptr;
  ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream3), cudaSuccess);

  config.SetExecStream(stream1);
  auto predictor = paddle_infer::CreatePredictor(config);
  auto x_t = predictor->GetInputHandle("x");
  x_t->Reshape({1, 3, 224, 224});
  // Heap-allocated, zero-filled input (about 600 KB); keeps it off the stack.
  std::vector<float> x_data(3 * 224 * 224, 0.0f);
  x_t->CopyFromCpu(x_data.data());

  ASSERT_TRUE(predictor->Run());
  cudaDeviceSynchronize();

  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), stream2));
  cudaDeviceSynchronize();

  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), stream3));
  cudaDeviceSynchronize();
}
// Creates two TensorRT-enabled predictors that start on the same stream, then
// moves both to a second and a third stream, expecting every run to succeed.
TEST(ReBindStream_multi, use_gpu) {
  std::string model_dir = FLAGS_infer_model + "/mobilenet";
  AnalysisConfig config1;
  config1.EnableUseGpu(100, 0);
  config1.SetModel(model_dir);
  config1.EnableTensorRtEngine();
  AnalysisConfig config2;
  config2.EnableUseGpu(100, 0);
  config2.EnableTensorRtEngine();
  config2.SetModel(model_dir);

  // Stop immediately if a stream cannot be created instead of handing an
  // uninitialized handle to the predictors.
  cudaStream_t stream1 = nullptr;
  cudaStream_t stream2 = nullptr;
  cudaStream_t stream3 = nullptr;
  ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream3), cudaSuccess);

  // Both predictors share stream1 initially.
  config1.SetExecStream(stream1);
  config2.SetExecStream(stream1);
  auto predictor1 = paddle_infer::CreatePredictor(config1);
  auto predictor2 = paddle_infer::CreatePredictor(config2);

  std::vector<float> x1(3 * 224 * 224, 1.0);
  auto x_t1 = predictor1->GetInputHandle("x");
  x_t1->Reshape({1, 3, 224, 224});
  x_t1->CopyFromCpu(x1.data());

  std::vector<float> x2(3 * 224 * 224, 2.0);
  auto x_t2 = predictor2->GetInputHandle("x");
  x_t2->Reshape({1, 3, 224, 224});
  x_t2->CopyFromCpu(x2.data());

  ASSERT_TRUE(predictor1->Run());
  cudaStreamSynchronize(stream1);
  ASSERT_TRUE(predictor2->Run());
  cudaStreamSynchronize(stream1);

  // Move both predictors to stream2.
  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor1.get(), stream2));
  cudaDeviceSynchronize();
  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor2.get(), stream2));
  cudaDeviceSynchronize();

  // Move both predictors to stream3.
  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor1.get(), stream3));
  cudaStreamSynchronize(stream3);
  ASSERT_TRUE(paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor2.get(), stream3));
  cudaStreamSynchronize(stream3);
}
// Creates three predictors that start on the same stream, runs each once on
// the main thread, then lets three threads switch their predictor between
// three streams concurrently, expecting every run to succeed.
TEST(SwitchStream_multi, use_gpu) {
  std::string model_dir = FLAGS_infer_model + "/mobilenet";
  AnalysisConfig config1;
  config1.EnableUseGpu(100, 0);
  config1.SetModel(model_dir);
  AnalysisConfig config2;
  config2.EnableUseGpu(100, 0);
  config2.SetModel(model_dir);
  AnalysisConfig config3;
  config3.EnableUseGpu(100, 0);
  config3.SetModel(model_dir);
  // TensorRT is left disabled for this case:
  // config1.EnableTensorRtEngine();
  // config2.EnableTensorRtEngine();
  // config3.EnableTensorRtEngine();

  // Stop immediately if a stream cannot be created instead of handing an
  // uninitialized handle to the predictors.
  cudaStream_t stream1 = nullptr;
  cudaStream_t stream2 = nullptr;
  cudaStream_t stream3 = nullptr;
  ASSERT_EQ(cudaStreamCreate(&stream1), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream2), cudaSuccess);
  ASSERT_EQ(cudaStreamCreate(&stream3), cudaSuccess);

  config1.SetExecStream(stream1);
  config2.SetExecStream(stream1);
  config3.SetExecStream(stream1);
  auto predictor1 = paddle_infer::CreatePredictor(config1);
  auto predictor2 = paddle_infer::CreatePredictor(config2);
  auto predictor3 = paddle_infer::CreatePredictor(config3);

  std::vector<float> x1(3 * 224 * 224, 1.0);
  auto x_t1 = predictor1->GetInputHandle("x");
  x_t1->Reshape({1, 3, 224, 224});
  x_t1->CopyFromCpu(x1.data());

  std::vector<float> x2(3 * 224 * 224, 2.0);
  auto x_t2 = predictor2->GetInputHandle("x");
  x_t2->Reshape({1, 3, 224, 224});
  x_t2->CopyFromCpu(x2.data());

  std::vector<float> x3(3 * 224 * 224, 2.5);
  auto x_t3 = predictor3->GetInputHandle("x");
  x_t3->Reshape({1, 3, 224, 224});
  x_t3->CopyFromCpu(x3.data());

  // TODO(wilber): fix.
  // NOTE: Each predictor must run once on the main thread first; the unit
  // test fails if these runs are removed, and the reason is not yet known.
  ASSERT_TRUE(predictor1->Run());
  cudaStreamSynchronize(stream1);
  ASSERT_TRUE(predictor2->Run());
  cudaStreamSynchronize(stream1);
  ASSERT_TRUE(predictor3->Run());
  cudaStreamSynchronize(stream1);

  // Runs `p` once per stream in `streams`. The result of each run is checked
  // with a non-fatal expectation so a failed run in a worker thread fails the
  // test instead of being silently dropped (this test is only built on
  // non-Windows platforms, where gtest assertions may be used from threads).
  auto Run = [&](paddle_infer::Predictor* p,
                 std::vector<cudaStream_t> streams) {
    for (auto s : streams) {
      EXPECT_TRUE(
          paddle_infer::experimental::InternalUtils::RunWithExternalStream(
              p, s));
    }
  };

  std::thread p1(Run,
                 predictor1.get(),
                 std::vector<cudaStream_t>{
                     stream1, stream2, stream3, stream3, stream2, stream2});
  std::thread p2(Run,
                 predictor2.get(),
                 std::vector<cudaStream_t>{
                     stream1, stream3, stream1, stream2, stream1, stream3});
  std::thread p3(Run,
                 predictor3.get(),
                 std::vector<cudaStream_t>{
                     stream1, stream1, stream2, stream3, stream3, stream2});
  p1.join();
  p2.join();
  p3.join();
  cudaDeviceSynchronize();
}
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册