diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 5d5719533e7a745e67949152ff2a83c1b06f2d06..4bfc24555d681ed163b4797c66b0ae35597419f1 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -24,6 +24,7 @@ #include #include +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h" @@ -1043,6 +1044,20 @@ bool AnalysisPredictor::ZeroCopyRun() { return true; } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) { + if (stream != nullptr) { + paddle::platform::DeviceContextPool &pool = + paddle::platform::DeviceContextPool::Instance(); + auto gpu_place = BOOST_GET_CONST(paddle::platform::CUDAPlace, place_); + auto *dev_ctx = static_cast<paddle::platform::CUDADeviceContext *>( + pool.Get(gpu_place)); + dev_ctx->SetThreadLocalStream(stream); + } + return ZeroCopyRun(); +} +#endif + void AnalysisPredictor::CollectShapeRangeInfo() { // if use gpu, sync first. if (config_.use_gpu()) { @@ -1567,4 +1582,25 @@ Predictor *PredictorPool::Retrive(size_t idx) { return preds_[idx - 1].get(); } } // namespace services + +namespace experimental { + +// Note: Can only be used under thread_local semantics. 
+bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p, + cudaStream_t stream) { +#ifdef PADDLE_WITH_CUDA + auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get()); + return pred != nullptr && pred->ExpRunWithExternalStream(stream); +#endif + return false; +} +bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p, + hipStream_t stream) { +#ifdef PADDLE_WITH_HIP + auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get()); + return pred != nullptr && pred->ExpRunWithExternalStream(stream); +#endif + return false; +} +} // namespace experimental } // namespace paddle_infer diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 9c36051757527c1357454a2ec7e98d5ce56147b1..70578bd201979b79e95e359d0da045dc956d2e09 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -25,6 +25,7 @@ #include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/string/printf.h" #ifdef PADDLE_WITH_TESTING @@ -172,6 +173,11 @@ class AnalysisPredictor : public PaddlePredictor { /// bool ZeroCopyRun() override; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // Note: Can only be used under thread_local semantics. + bool ExpRunWithExternalStream(const gpuStream_t stream); +#endif + /// /// \brief Create feed fetch variables /// diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 35b90bfa54f73cbc6d9dac97bfd55dadf11a1dca..b2b9f2e40747855f211ed79bf053afbca41f55ee 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -41,11 +41,27 @@ limitations under the License. 
*/ /// \since 2.0.0-beta /// +// forward declaration +using cudaStream_t = struct CUstream_st*; +using hipStream_t = struct ihipStream_t*; + namespace paddle_infer { using PrecisionType = paddle::AnalysisConfig::Precision; using Config = paddle::AnalysisConfig; +class Predictor; +namespace experimental { +class PD_INFER_DECL InternalUtils { + public: + // Note: Can only be used under thread_local semantics. + static bool RunWithExternalStream(paddle_infer::Predictor* pred, + cudaStream_t stream); + static bool RunWithExternalStream(paddle_infer::Predictor* pred, + hipStream_t stream); +}; +} // namespace experimental + /// /// \class Predictor /// @@ -150,6 +166,7 @@ class PD_INFER_DECL Predictor { private: std::unique_ptr predictor_; + friend class paddle_infer::experimental::InternalUtils; }; /// diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 60442eb4a0e9e25eea40ac1b937693de6da5820a..07508da703d35bee3cc617d67de0e1ccc7846452 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -488,6 +488,26 @@ CUDAContext::CUDAContext(const CUDAPlace& place, #endif } +void CUDAContext::SetStream(gpuStream_t stream) { + if (stream_->raw_stream() != stream) { + CUDADeviceGuard guard(place_.device); + DestoryCuDNNContext(); + DestoryCuBlasContext(); +#ifndef PADDLE_WITH_HIP + DestoryCuSolverContext(); +#endif + + stream_->SetStream(stream); + + InitEigenContext(); + InitCuBlasContext(); + InitCuDNNContext(); +#ifndef PADDLE_WITH_HIP + InitCuSolverContext(); +#endif + } +} + CUDAContext::~CUDAContext() { CUDADeviceGuard guard(place_.device); DestoryCuDNNContext(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 72fa525040b5f71d1ecf50fb391d8a4a6b1a8ab0..4b38e5ddf307196a8cadf41b944fe1abe4e2de94 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -334,6 +334,8 @@ class CUDAContext { return 
old_stream_ptr; } + void SetStream(gpuStream_t stream); + const gpuStream_t& RawStream() { return stream_->raw_stream(); } #ifdef PADDLE_WITH_HIP @@ -616,6 +618,11 @@ class CUDADeviceContext : public DeviceContext { return thread_ctx_.at(this); } + // Note: thread_local semantics only; rebinds thread_ctx_.at(this)'s stream. + void SetThreadLocalStream(const gpuStream_t stream) { + thread_ctx_.at(this)->SetStream(stream); + } + private: CUDAPlace place_; std::shared_ptr default_ctx_; diff --git a/paddle/fluid/platform/stream/cuda_stream.cc b/paddle/fluid/platform/stream/cuda_stream.cc index dafb61fe0aaf4379099522e8f270166b56349eae..742d267b5954353da2362fe9201426b23ead63a9 100644 --- a/paddle/fluid/platform/stream/cuda_stream.cc +++ b/paddle/fluid/platform/stream/cuda_stream.cc @@ -56,7 +56,7 @@ void CUDAStream::Destroy() { CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device); Wait(); WaitCallback(); - if (stream_) { + if (stream_ && owned_stream_) { #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); #else @@ -92,6 +92,20 @@ void CUDAStream::Wait() const { PADDLE_ENFORCE_GPU_SUCCESS(e_sync); } +// Note: Can only be used under thread_local semantics. 
+void CUDAStream::SetStream(gpuStream_t stream) { + if (owned_stream_ && stream_ && stream_ != stream) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_)); +#endif + } + owned_stream_ = owned_stream_ && (stream_ == stream); + stream_ = stream; + callback_manager_.reset(new StreamCallbackManager(stream_)); +} + CUDAStream* get_current_stream(int deviceId) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (deviceId == -1) { diff --git a/paddle/fluid/platform/stream/cuda_stream.h b/paddle/fluid/platform/stream/cuda_stream.h index 36f31c46673b2f0a6dcec5c01dcc12099798a6d5..0683cf4b0424ed6b3771a253de6c5c18589f640f 100644 --- a/paddle/fluid/platform/stream/cuda_stream.h +++ b/paddle/fluid/platform/stream/cuda_stream.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream_callback_manager.h" @@ -130,8 +131,12 @@ class CUDAStream final { const Place& GetPlace() const { return place_; } + // Note: thread_local semantics only. The new stream is borrowed, not owned. + void SetStream(gpuStream_t stream); + private: Place place_; + bool owned_stream_{true}; #ifdef PADDLE_WITH_HIP hipStream_t stream_{nullptr}; #else