From 15ad7ee4f36f93edb979f2af75216a3df15afae8 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 23 Dec 2021 15:45:05 +0800 Subject: [PATCH] Support external stream. (#38373) * support external stream. * update * update * update --- .../fluid/inference/api/analysis_predictor.cc | 36 +++++++++++++++++++ .../fluid/inference/api/analysis_predictor.h | 6 ++++ .../inference/api/paddle_inference_api.h | 17 +++++++++ paddle/fluid/platform/device_context.cc | 20 +++++++++++ paddle/fluid/platform/device_context.h | 7 ++++ paddle/fluid/platform/stream/cuda_stream.cc | 16 ++++++++- paddle/fluid/platform/stream/cuda_stream.h | 5 +++ 7 files changed, 106 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 5d5719533e..4bfc24555d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -24,6 +24,7 @@ #include #include +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h" @@ -1043,6 +1044,20 @@ bool AnalysisPredictor::ZeroCopyRun() { return true; } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) { + if (stream != nullptr) { + paddle::platform::DeviceContextPool &pool = + paddle::platform::DeviceContextPool::Instance(); + auto gpu_place = BOOST_GET_CONST(paddle::platform::CUDAPlace, place_); + auto *dev_ctx = static_cast<paddle::platform::CUDADeviceContext *>( + pool.Get(gpu_place)); + dev_ctx->SetThreadLocalStream(stream); + } + return ZeroCopyRun(); +} +#endif + void AnalysisPredictor::CollectShapeRangeInfo() { // if use gpu, sync first. 
if (config_.use_gpu()) { @@ -1567,4 +1582,25 @@ Predictor *PredictorPool::Retrive(size_t idx) { return preds_[idx - 1].get(); } } // namespace services + +namespace experimental { + +// Note: Can only be used under thread_local semantics. +bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p, + cudaStream_t stream) { +#ifdef PADDLE_WITH_CUDA + auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get()); + return pred != nullptr && pred->ExpRunWithExternalStream(stream); +#endif + return false; +} +bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p, + hipStream_t stream) { +#ifdef PADDLE_WITH_HIP + auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get()); + return pred != nullptr && pred->ExpRunWithExternalStream(stream); +#endif + return false; +} +} // namespace experimental } // namespace paddle_infer diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 9c36051757..70578bd201 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -25,6 +25,7 @@ #include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/string/printf.h" #ifdef PADDLE_WITH_TESTING @@ -172,6 +173,11 @@ class AnalysisPredictor : public PaddlePredictor { /// bool ZeroCopyRun() override; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // Note: Can only be used under thread_local semantics. 
+ bool ExpRunWithExternalStream(const gpuStream_t stream); +#endif + /// /// \brief Create feed fetch variables /// diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 35b90bfa54..b2b9f2e407 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -41,11 +41,27 @@ limitations under the License. */ /// \since 2.0.0-beta /// +// forward declarations +using cudaStream_t = struct CUstream_st*; +using hipStream_t = struct ihipStream_t*; + namespace paddle_infer { using PrecisionType = paddle::AnalysisConfig::Precision; using Config = paddle::AnalysisConfig; +class Predictor; +namespace experimental { +class PD_INFER_DECL InternalUtils { + public: + // Note: Can only be used under thread_local semantics. + static bool RunWithExternalStream(paddle_infer::Predictor* pred, + cudaStream_t stream); + static bool RunWithExternalStream(paddle_infer::Predictor* pred, + hipStream_t stream); +}; +} // namespace experimental + /// /// \class Predictor /// @@ -150,6 +166,7 @@ class PD_INFER_DECL Predictor { private: std::unique_ptr<paddle::PaddlePredictor> predictor_; + friend class paddle_infer::experimental::InternalUtils; }; /// diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 60442eb4a0..07508da703 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -488,6 +488,26 @@ CUDAContext::CUDAContext(const CUDAPlace& place, #endif } +void CUDAContext::SetStream(gpuStream_t stream) { + if (stream_->raw_stream() != stream) { + CUDADeviceGuard guard(place_.device); + DestoryCuDNNContext(); + DestoryCuBlasContext(); +#ifndef PADDLE_WITH_HIP + DestoryCuSolverContext(); +#endif + + stream_->SetStream(stream); + + InitEigenContext(); + InitCuBlasContext(); + InitCuDNNContext(); +#ifndef PADDLE_WITH_HIP + InitCuSolverContext(); +#endif + } +} + CUDAContext::~CUDAContext() { CUDADeviceGuard 
guard(place_.device); DestoryCuDNNContext(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 72fa525040..4b38e5ddf3 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -334,6 +334,8 @@ class CUDAContext { return old_stream_ptr; } + void SetStream(gpuStream_t stream); + const gpuStream_t& RawStream() { return stream_->raw_stream(); } #ifdef PADDLE_WITH_HIP @@ -616,6 +618,11 @@ class CUDADeviceContext : public DeviceContext { return thread_ctx_.at(this); } + // Note: Can only be used under thread_local semantics. + void SetThreadLocalStream(const gpuStream_t stream) { + thread_ctx_.at(this)->SetStream(stream); + } + private: CUDAPlace place_; std::shared_ptr<CUDAContext> default_ctx_; diff --git a/paddle/fluid/platform/stream/cuda_stream.cc b/paddle/fluid/platform/stream/cuda_stream.cc index dafb61fe0a..742d267b59 100644 --- a/paddle/fluid/platform/stream/cuda_stream.cc +++ b/paddle/fluid/platform/stream/cuda_stream.cc @@ -56,7 +56,7 @@ void CUDAStream::Destroy() { CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device); Wait(); WaitCallback(); - if (stream_) { + if (stream_ && owned_stream_) { #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); #else @@ -92,6 +92,20 @@ void CUDAStream::Wait() const { PADDLE_ENFORCE_GPU_SUCCESS(e_sync); } +// Note: Can only be used under thread_local semantics. 
+void CUDAStream::SetStream(gpuStream_t stream) { + if (owned_stream_ && stream_) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_)); +#endif + } + owned_stream_ = false; + stream_ = stream; + callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_)); +} + CUDAStream* get_current_stream(int deviceId) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (deviceId == -1) { diff --git a/paddle/fluid/platform/stream/cuda_stream.h b/paddle/fluid/platform/stream/cuda_stream.h index 36f31c4667..0683cf4b04 100644 --- a/paddle/fluid/platform/stream/cuda_stream.h +++ b/paddle/fluid/platform/stream/cuda_stream.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream_callback_manager.h" @@ -130,8 +131,12 @@ class CUDAStream final { const Place& GetPlace() const { return place_; } + // Note: Can only be used under thread_local semantics. + void SetStream(gpuStream_t stream); + private: Place place_; + bool owned_stream_{true}; #ifdef PADDLE_WITH_HIP hipStream_t stream_{nullptr}; #else -- GitLab