Unverified commit 15ad7ee4, authored by W Wilber, committed by GitHub

Support external stream. (#38373)

* support external stream.

* update

* update

* update
Parent b7bafee8
@@ -24,6 +24,7 @@
#include <utility>
#include <vector>
#include "paddle/fluid//platform/device/gpu/gpu_types.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
@@ -1043,6 +1044,20 @@ bool AnalysisPredictor::ZeroCopyRun() {
return true;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
if (stream != nullptr) {
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(paddle::platform::CUDAPlace, place_);
auto *dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place));
dev_ctx->SetThreadLocalStream(stream);
}
return ZeroCopyRun();
}
#endif
void AnalysisPredictor::CollectShapeRangeInfo() {
// If GPU is used, synchronize first.
if (config_.use_gpu()) {
@@ -1567,4 +1582,25 @@ Predictor *PredictorPool::Retrive(size_t idx) {
return preds_[idx - 1].get();
}
} // namespace services
namespace experimental {
// Note: Can only be used under thread_local semantics.
bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p,
cudaStream_t stream) {
#ifdef PADDLE_WITH_CUDA
auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
return pred->ExpRunWithExternalStream(stream);
#endif
return false;
}
bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p,
hipStream_t stream) {
#ifdef PADDLE_WITH_HIP
auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
return pred->ExpRunWithExternalStream(stream);
#endif
return false;
}
} // namespace experimental
} // namespace paddle_infer
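For orientation, here is a usage sketch of the new API (not part of this commit; the model path and config calls are placeholders). The caller creates and owns the CUDA stream, hands it to the predictor through `InternalUtils::RunWithExternalStream`, and stays responsible for synchronizing and destroying it:

```cpp
// Hypothetical usage sketch; config details are placeholders and error
// handling is omitted.
#include <cuda_runtime.h>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  auto predictor = paddle_infer::CreatePredictor(config);

  cudaStream_t stream;
  cudaStreamCreate(&stream);  // caller-owned; Paddle will not destroy it

  // ... set input tensors via predictor->GetInputHandle(...) ...

  // Rebinds this thread's CUDAContext to `stream` and calls ZeroCopyRun().
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), stream);

  cudaStreamSynchronize(stream);  // outputs are ready after this point
  cudaStreamDestroy(stream);
  return 0;
}
```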
@@ -25,6 +25,7 @@
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_TESTING
@@ -172,6 +173,11 @@ class AnalysisPredictor : public PaddlePredictor {
///
bool ZeroCopyRun() override;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Note: Can only be used under thread_local semantics.
bool ExpRunWithExternalStream(const gpuStream_t stream);
#endif
///
/// \brief Create feed fetch variables
///
......
@@ -41,11 +41,27 @@ limitations under the License. */
/// \since 2.0.0-beta
///
// forward declaration
using cudaStream_t = struct CUstream_st*;
using hipStream_t = struct ihipStream_t*;
namespace paddle_infer {
using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig;
class Predictor;
namespace experimental {
class PD_INFER_DECL InternalUtils {
public:
// Note: Can only be used under thread_local semantics.
static bool RunWithExternalStream(paddle_infer::Predictor* pred,
cudaStream_t stream);
static bool RunWithExternalStream(paddle_infer::Predictor* pred,
hipStream_t stream);
};
} // namespace experimental
///
/// \class Predictor
///
@@ -150,6 +166,7 @@ class PD_INFER_DECL Predictor {
private:
std::unique_ptr<paddle::PaddlePredictor> predictor_;
friend class paddle_infer::experimental::InternalUtils;
};
///
......
@@ -488,6 +488,26 @@ CUDAContext::CUDAContext(const CUDAPlace& place,
#endif
}
void CUDAContext::SetStream(gpuStream_t stream) {
if (stream_->raw_stream() != stream) {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
DestoryCuSolverContext();
#endif
stream_->SetStream(stream);
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
InitCuSolverContext();
#endif
}
}
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
......
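`CUDAContext::SetStream` destroys and re-initializes the cuDNN/cuBLAS (and, on CUDA, cuSolver) contexts rather than rebinding the existing handles. For comparison, the vendor libraries do offer in-place rebinding; a minimal illustration using standard CUDA library calls (not code from this commit):

```cpp
// Standard cuBLAS/cuDNN calls that attach existing handles to a new stream,
// shown for comparison only. Re-creating the contexts, as this commit does,
// additionally drops any handle state that was tied to the old stream.
#include <cublas_v2.h>
#include <cudnn.h>

void RebindHandles(cublasHandle_t blas, cudnnHandle_t dnn, cudaStream_t s) {
  cublasSetStream(blas, s);  // later cuBLAS work is issued on `s`
  cudnnSetStream(dnn, s);    // later cuDNN work is issued on `s`
}
```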
@@ -334,6 +334,8 @@ class CUDAContext {
return old_stream_ptr;
}
void SetStream(gpuStream_t stream);
const gpuStream_t& RawStream() { return stream_->raw_stream(); }
#ifdef PADDLE_WITH_HIP
@@ -616,6 +618,11 @@ class CUDADeviceContext : public DeviceContext {
return thread_ctx_.at(this);
}
// Note: Can only be used under thread_local semantics.
void SetThreadLocalStream(const gpuStream_t stream) {
thread_ctx_.at(this)->SetStream(stream);
}
private:
CUDAPlace place_;
std::shared_ptr<CUDAContext> default_ctx_;
......
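The recurring note about thread_local semantics means the stream is bound only to the calling thread's CUDAContext (`thread_ctx_.at(this)` resolves to a per-thread context). A minimal sketch of the implied pattern, assuming the conventional one-predictor-per-thread setup:

```cpp
// Sketch of the thread-local contract (assumes one predictor per thread).
// Each worker installs its own stream; other threads' contexts are
// unaffected.
#include <cuda_runtime.h>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void Worker(paddle_infer::Predictor* predictor) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Only this thread's CUDAContext switches to `stream`.
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(predictor,
                                                                   stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}
```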
@@ -56,7 +56,7 @@ void CUDAStream::Destroy() {
CUDADeviceGuard guard(BOOST_GET_CONST(CUDAPlace, place_).device);
Wait();
WaitCallback();
-  if (stream_) {
+  if (stream_ && owned_stream_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
@@ -92,6 +92,20 @@ void CUDAStream::Wait() const {
PADDLE_ENFORCE_GPU_SUCCESS(e_sync);
}
// Note: Can only be used under thread_local semantics.
void CUDAStream::SetStream(gpuStream_t stream) {
if (owned_stream_ && stream_) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream_));
#endif
}
owned_stream_ = false;
stream_ = stream;
callback_manager_.reset(new StreamCallbackManager<gpuStream_t>(stream_));
}
CUDAStream* get_current_stream(int deviceId) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (deviceId == -1) {
......
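Together, `SetStream` (which clears `owned_stream_`) and the `stream_ && owned_stream_` guard in `Destroy` define the ownership contract: Paddle never destroys an externally supplied stream, so the caller must keep it alive until all queued work completes. A small caller-side RAII helper (hypothetical, not part of Paddle) that honors this contract:

```cpp
// Hypothetical caller-side helper matching the ownership contract: the
// external stream is created, drained, and destroyed by the caller, never by
// CUDAStream::Destroy().
#include <cuda_runtime.h>

class ScopedCudaStream {
 public:
  ScopedCudaStream() { cudaStreamCreate(&stream_); }
  ~ScopedCudaStream() {
    cudaStreamSynchronize(stream_);  // drain queued work before destruction
    cudaStreamDestroy(stream_);
  }
  cudaStream_t get() const { return stream_; }

 private:
  cudaStream_t stream_{nullptr};
};
```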
@@ -18,6 +18,7 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream_callback_manager.h"
@@ -130,8 +131,12 @@ class CUDAStream final {
const Place& GetPlace() const { return place_; }
// Note: Can only be used under thread_local semantics.
void SetStream(gpuStream_t stream);
private:
Place place_;
bool owned_stream_{true};
#ifdef PADDLE_WITH_HIP
hipStream_t stream_{nullptr};
#else
......