diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 40bc64442043eab5905a0941533b1f90cd5bf1e9..dd2fd1ed23fa58e6f8de7b65294a6fc62a3bfcce 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -355,7 +355,7 @@ void Predictor::GenRuntimeProgram() { program_generated_ = true; #ifdef LITE_WITH_CUDA if (!cuda_use_multi_stream_) { - program_->UpdateContext(cuda_exec_stream_, cuda_io_stream_); + program_->UpdateCudaContext(cuda_exec_stream_, cuda_io_stream_); } #endif } diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index c452e1b6e42ee7066abb4fedf7585c1ea4331bc3..004fbae071412faeee60d18ad73594956c097297 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -29,7 +29,7 @@ #ifdef LITE_WITH_CUDA #include "lite/backends/cuda/cuda_utils.h" -#include "lite/backends/cuda/stream_wrapper.h" +#include "lite/backends/cuda/stream_guard.h" #endif namespace paddle { @@ -254,12 +254,12 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { #ifdef LITE_WITH_CUDA bool cuda_use_multi_stream_{false}; - std::unique_ptr<lite::StreamWrapper> cuda_io_stream_; - std::unique_ptr<lite::StreamWrapper> cuda_exec_stream_; + std::unique_ptr<lite::CudaStreamGuard> cuda_io_stream_; + std::unique_ptr<lite::CudaStreamGuard> cuda_exec_stream_; cudaEvent_t cuda_input_event_; std::vector<cudaEvent_t> cuda_output_events_; // only used for multi exec stream mode. - std::vector<lite::StreamWrapper> cuda_exec_streams_; + std::vector<lite::CudaStreamGuard> cuda_exec_streams_; #endif }; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index e26a3cf4d496118426b8acc1ab49cb9bb24c8e81..cce790544481c1940bd85762aa8259fff4e85b73 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -97,14 +97,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) { // init two streams for each predictor. 
if (config_.cuda_exec_stream()) { cuda_exec_stream_.reset( - new lite::StreamWrapper(*config_.cuda_exec_stream())); + new lite::CudaStreamGuard(*config_.cuda_exec_stream())); } else { - cuda_exec_stream_.reset(new lite::StreamWrapper()); + cuda_exec_stream_.reset(new lite::CudaStreamGuard()); } if (config_.cuda_io_stream()) { - cuda_io_stream_.reset(new lite::StreamWrapper(*config_.cuda_io_stream())); + cuda_io_stream_.reset(new lite::CudaStreamGuard(*config_.cuda_io_stream())); } else { - cuda_io_stream_.reset(new lite::StreamWrapper()); + cuda_io_stream_.reset(new lite::CudaStreamGuard()); } raw_predictor_->set_cuda_exec_stream(cuda_exec_stream_->stream()); diff --git a/lite/api/test_resnet50_lite_cuda.cc b/lite/api/test_resnet50_lite_cuda.cc index 138e930ba958122acff7c0e0694f5d5d178a3f92..fe9232da9d1c01cedf1e45061e8f43383188e342 100644 --- a/lite/api/test_resnet50_lite_cuda.cc +++ b/lite/api/test_resnet50_lite_cuda.cc @@ -29,7 +29,7 @@ namespace paddle { namespace lite { -void RunModel(lite_api::CxxConfig config) { +void RunModel(const lite_api::CxxConfig& config) { auto predictor = lite_api::CreatePaddlePredictor(config); const int batch_size = 4; const int channels = 3; diff --git a/lite/backends/cuda/CMakeLists.txt b/lite/backends/cuda/CMakeLists.txt index be779c6d2d3dd23483983f146532c5a8573392c2..59db53097dbaebf545f833b0b3d23b4aabd97c37 100644 --- a/lite/backends/cuda/CMakeLists.txt +++ b/lite/backends/cuda/CMakeLists.txt @@ -9,6 +9,6 @@ nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_deps}) nv_library(nvtx_wrapper SRCS nvtx_wrapper DEPS ${cuda_deps}) lite_cc_library(cuda_context SRCS context.cc DEPS device_info) -lite_cc_library(stream_wrapper SRCS stream_wrapper.cc DEPS target_wrapper_cuda ${cuda_deps}) +lite_cc_library(stream_guard SRCS stream_guard.cc DEPS target_wrapper_cuda ${cuda_deps}) add_subdirectory(math) diff --git a/lite/backends/cuda/stream_wrapper.cc b/lite/backends/cuda/stream_guard.cc similarity index 94% rename from 
lite/backends/cuda/stream_wrapper.cc rename to lite/backends/cuda/stream_guard.cc index 75b1f944b1078905c689817f4506548722109726..02f8b9d85e5bd599ea08dcbef2062f468efac74b 100644 --- a/lite/backends/cuda/stream_wrapper.cc +++ b/lite/backends/cuda/stream_guard.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/cuda/stream_wrapper.h" +#include "lite/backends/cuda/stream_guard.h" #include "lite/backends/cuda/cuda_utils.h" namespace paddle { diff --git a/lite/backends/cuda/stream_wrapper.h b/lite/backends/cuda/stream_guard.h similarity index 63% rename from lite/backends/cuda/stream_wrapper.h rename to lite/backends/cuda/stream_guard.h index ac62ea99055843ad42bbfe6d0cfdcead6fffabcd..e59d8514c58c35b9d409bf979c1ccbebcd560113 100644 --- a/lite/backends/cuda/stream_wrapper.h +++ b/lite/backends/cuda/stream_guard.h @@ -21,24 +21,35 @@ namespace paddle { namespace lite { -class StreamWrapper { +// CudaStreamGuard is an encapsulation of cudaStream_t, which can accept external +// stream or internally created stream +// +// std::unique_ptr<CudaStreamGuard> sm; +// +// external stream: exec_stream +// sm.reset(new CudaStreamGuard(exec_stream)); +// internal stream +// sm.reset(new CudaStreamGuard()); +// get cudaStream_t +// sm->stream(); +class CudaStreamGuard { public: - explicit StreamWrapper(cudaStream_t stream) : stream_(stream), owner_(false) {} - StreamWrapper() : owner_(true) { + explicit CudaStreamGuard(cudaStream_t stream) + : stream_(stream), owned_(false) {} + CudaStreamGuard() : owned_(true) { lite::TargetWrapperCuda::CreateStream(&stream_); } - ~StreamWrapper() { - if (owner_) { + ~CudaStreamGuard() { + if (owned_) { lite::TargetWrapperCuda::DestroyStream(stream_); } } cudaStream_t stream() { return stream_; } - bool owner() { return owner_; } + bool owned() { return owned_; } private: cudaStream_t stream_; - bool owner_; + bool owned_{false}; }; } // namespace lite diff --git 
a/lite/core/program.cc b/lite/core/program.cc index 6b1a93e47b93983fd9f1ede0d79880093354cf07..5aec6ee229d19ba164f10862619493253c21f541 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -71,7 +71,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { std::map<std::string, cpp::VarDesc> origin_var_maps; auto& main_block = *desc->GetBlock(0); auto var_size = main_block.VarsSize(); - for (int i = 0; i < static_cast<int>(var_size); i++) { + for (size_t i = 0; i < var_size; i++) { auto v = main_block.GetVar(i); auto name = v->Name(); origin_var_maps.emplace(name, *v); @@ -144,9 +144,9 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { } #ifdef LITE_WITH_CUDA -void RuntimeProgram::UpdateContext(cudaStream_t exec, cudaStream_t io) { +void RuntimeProgram::UpdateCudaContext(cudaStream_t exec, cudaStream_t io) { for (auto& inst : instructions_) { - inst.UpdateContext(exec, io); + inst.UpdateCudaContext(exec, io); } } #endif diff --git a/lite/core/program.h b/lite/core/program.h index ecd264cb9be54e53822dc3dfd79924647707f9d8..d47bb1c09e039af1abcd6c2d94b9aa9f3063f977 100644 --- a/lite/core/program.h +++ b/lite/core/program.h @@ -129,7 +129,7 @@ struct Instruction { } } void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); } - void UpdateContext(cudaStream_t exec, cudaStream_t io) { + void UpdateCudaContext(cudaStream_t exec, cudaStream_t io) { if (kernel_->target() == TargetType::kCUDA) { kernel_->mutable_context()->As<CUDAContext>().SetExecStream(exec); kernel_->mutable_context()->As<CUDAContext>().SetIoStream(io); @@ -223,9 +223,9 @@ class LITE_API RuntimeProgram { void UpdateVarsOfProgram(cpp::ProgramDesc* desc); #ifdef LITE_WITH_CUDA - // UpdateContext will update the exec stream and io stream of all kernels in - // the program. - void UpdateContext(cudaStream_t exec, cudaStream_t io); + // UpdateCudaContext will update the exec stream and io stream of all kernels + // in the program. + void UpdateCudaContext(cudaStream_t exec, cudaStream_t io); #endif private: