From f9077aa4847c5e5308280167c33046e7b51c0ec5 Mon Sep 17 00:00:00 2001
From: jiweibo
Date: Thu, 23 Jul 2020 05:33:37 +0000
Subject: [PATCH] update for subgraph. test=develop

---
 lite/api/cxx_api.cc               |  2 +-
 lite/api/cxx_api.h                | 14 ++++++++------
 lite/api/cxx_api_impl.cc          | 18 +++++++++---------
 lite/backends/cuda/stream_guard.h |  2 +-
 lite/core/program.cc              |  6 +++---
 lite/core/program.h               | 14 +++++++++-----
 6 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 46ef543d24..77342ffdb7 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -357,7 +357,7 @@ void Predictor::GenRuntimeProgram() {
   program_generated_ = true;
 #ifdef LITE_WITH_CUDA
   if (!cuda_use_multi_stream_) {
-    program_->UpdateCudaContext(cuda_exec_stream_, cuda_io_stream_);
+    program_->UpdateCudaStream(cuda_exec_stream_, cuda_io_stream_);
   }
 #endif
 }
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 4520c52ad8..89fbd99d8d 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -158,10 +158,12 @@ class LITE_API Predictor {
     cuda_use_multi_stream_ = multi_stream;
   }
   bool cuda_use_multi_stream() { return cuda_use_multi_stream_; }
-  void set_cuda_exec_stream(cudaStream_t stream) { cuda_exec_stream_ = stream; }
-  void set_cuda_io_stream(cudaStream_t stream) { cuda_io_stream_ = stream; }
-  cudaStream_t cuda_exec_stream() { return cuda_exec_stream_; }
-  cudaStream_t cuda_io_stream() { return cuda_io_stream_; }
+  void set_cuda_exec_stream(cudaStream_t* stream) {
+    cuda_exec_stream_ = stream;
+  }
+  void set_cuda_io_stream(cudaStream_t* stream) { cuda_io_stream_ = stream; }
+  cudaStream_t* cuda_exec_stream() { return cuda_exec_stream_; }
+  cudaStream_t* cuda_io_stream() { return cuda_io_stream_; }
 #endif
 
  private:
@@ -177,8 +179,8 @@ class LITE_API Predictor {
 
 #ifdef LITE_WITH_CUDA
   bool cuda_use_multi_stream_{false};
-  cudaStream_t cuda_io_stream_;
-  cudaStream_t cuda_exec_stream_;
+  cudaStream_t* cuda_io_stream_{nullptr};
+  cudaStream_t* cuda_exec_stream_{nullptr};
 #endif
 };
 
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index 394bc6c0b3..142e417a7f 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -135,14 +135,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) {
 }
 
 void CxxPaddleApiImpl::SyncCudaInputs() {
-  TargetWrapperCuda::RecordEvent(cuda_input_event_, cuda_io_stream_->stream());
+  TargetWrapperCuda::RecordEvent(cuda_input_event_, *cuda_io_stream_->stream());
   if (cuda_use_multi_stream_) {
     for (int i = 0; i < lite::kMaxStream; ++i) {
-      TargetWrapperCuda::StreamSync(cuda_exec_streams_[i].stream(),
+      TargetWrapperCuda::StreamSync(*cuda_exec_streams_[i].stream(),
                                     cuda_input_event_);
     }
   } else {
-    TargetWrapperCuda::StreamSync(cuda_exec_stream_->stream(),
+    TargetWrapperCuda::StreamSync(*cuda_exec_stream_->stream(),
                                   cuda_input_event_);
   }
 }
@@ -151,14 +151,14 @@ void CxxPaddleApiImpl::SyncCudaOutputs() {
   if (cuda_use_multi_stream_) {
     for (size_t i = 0; i < cuda_output_events_.size(); ++i) {
       TargetWrapperCuda::RecordEvent(cuda_output_events_[i],
-                                     cuda_exec_streams_[i].stream());
-      TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(),
+                                     *cuda_exec_streams_[i].stream());
+      TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
                                     cuda_output_events_[i]);
     }
   } else {
     TargetWrapperCuda::RecordEvent(cuda_output_events_[0],
-                                   cuda_exec_stream_->stream());
-    TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(),
+                                   *cuda_exec_stream_->stream());
+    TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
                                   cuda_output_events_[0]);
   }
 }
@@ -168,7 +168,7 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
   auto *x = raw_predictor_->GetInput(i);
 #ifdef LITE_WITH_CUDA
   return std::unique_ptr<lite_api::Tensor>(
-      new lite_api::Tensor(x, cuda_io_stream_->stream()));
+      new lite_api::Tensor(x, *cuda_io_stream_->stream()));
 #else
   return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
 #endif
@@ -179,7 +179,7 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
   const auto *x = raw_predictor_->GetOutput(i);
 #ifdef LITE_WITH_CUDA
   return std::unique_ptr<const lite_api::Tensor>(
-      new lite_api::Tensor(x, cuda_io_stream_->stream()));
+      new lite_api::Tensor(x, *cuda_io_stream_->stream()));
 #else
   return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
 #endif
diff --git a/lite/backends/cuda/stream_guard.h b/lite/backends/cuda/stream_guard.h
index e59d8514c5..69a5015a67 100644
--- a/lite/backends/cuda/stream_guard.h
+++ b/lite/backends/cuda/stream_guard.h
@@ -44,7 +44,7 @@ class CudaStreamGuard {
       lite::TargetWrapperCuda::DestroyStream(stream_);
     }
   }
-  cudaStream_t stream() { return stream_; }
+  cudaStream_t* stream() { return &stream_; }
   bool owned() { return owned_; }
 
  private:
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 289787d3bf..4d79c46213 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -243,9 +243,9 @@ RuntimeProgram::RuntimeProgram(
 }
 
 #ifdef LITE_WITH_CUDA
-void RuntimeProgram::UpdateCudaContext(cudaStream_t exec, cudaStream_t io) {
-  for (auto& inst : instructions_) {
-    inst.UpdateCudaContext(exec, io);
+void RuntimeProgram::UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
+  for (auto& inst : instructions_[kRootBlockIdx]) {
+    inst.UpdateCudaStream(exec, io);
   }
 }
 #endif
diff --git a/lite/core/program.h b/lite/core/program.h
index 50c4bb37d6..1adf0a1710 100644
--- a/lite/core/program.h
+++ b/lite/core/program.h
@@ -140,10 +140,14 @@ struct Instruction {
     }
   }
   void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); }
-  void UpdateCudaContext(cudaStream_t exec, cudaStream_t io) {
+  void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
     if (kernel_->target() == TargetType::kCUDA) {
-      kernel_->mutable_context()->As<CUDAContext>().SetExecStream(exec);
-      kernel_->mutable_context()->As<CUDAContext>().SetIoStream(io);
+      if (exec) {
+        kernel_->mutable_context()->As<CUDAContext>().SetExecStream(*exec);
+      }
+      if (io) {
+        kernel_->mutable_context()->As<CUDAContext>().SetIoStream(*io);
+      }
     }
   }
 #endif
@@ -245,9 +249,9 @@ class LITE_API RuntimeProgram {
   void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc);
 
 #ifdef LITE_WITH_CUDA
-  // UpdateCudaContext will update the exec stream and io stream of all kernels
+  // UpdateCudaStream will update the exec stream and io stream of all kernels
   // in the program.
-  void UpdateCudaContext(cudaStream_t exec, cudaStream_t io);
+  void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io);
 #endif
 
  private:
-- 
GitLab
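
For reference, a minimal sketch of how a caller might drive the pointer-based stream setters introduced above. Only `set_cuda_io_stream`, `set_cuda_exec_stream`, `GenRuntimeProgram`, and the dereference-on-use behaviour come from the diff itself; the include path, namespace, stream creation, and the direct call to `GenRuntimeProgram()` are assumptions.

```cpp
// Sketch only (not part of the patch): exercises the new cudaStream_t*-based
// API. Include path, namespace, and stream ownership policy are assumptions.
#include <cuda_runtime.h>

#include "lite/api/cxx_api.h"  // assumed header providing paddle::lite::Predictor

void ConfigureStreams(paddle::lite::Predictor* predictor) {
  // The caller owns the streams; the predictor only stores their addresses
  // (cudaStream_t*, defaulting to nullptr) and dereferences them when it
  // updates each CUDA kernel's context.
  static cudaStream_t io_stream;
  static cudaStream_t exec_stream;
  cudaStreamCreate(&io_stream);
  cudaStreamCreate(&exec_stream);

  predictor->set_cuda_io_stream(&io_stream);
  predictor->set_cuda_exec_stream(&exec_stream);

  // With multi-stream disabled (cuda_use_multi_stream_ is false by default),
  // GenRuntimeProgram() forwards both pointers to every CUDA kernel through
  // RuntimeProgram::UpdateCudaStream().
  predictor->GenRuntimeProgram();
}
```

The streams must outlive the predictor, since it keeps only the pointers; that is the ownership model this patch moves to, mirrored by `CudaStreamGuard::stream()` now returning `cudaStream_t*` instead of a copy of the handle.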