提交 f9077aa4 编写于 作者: J jiweibo

update for subgraph. test=develop

上级 47ebc8c4
...@@ -357,7 +357,7 @@ void Predictor::GenRuntimeProgram() { ...@@ -357,7 +357,7 @@ void Predictor::GenRuntimeProgram() {
program_generated_ = true; program_generated_ = true;
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
if (!cuda_use_multi_stream_) { if (!cuda_use_multi_stream_) {
program_->UpdateCudaContext(cuda_exec_stream_, cuda_io_stream_); program_->UpdateCudaStream(cuda_exec_stream_, cuda_io_stream_);
} }
#endif #endif
} }
......
...@@ -158,10 +158,12 @@ class LITE_API Predictor { ...@@ -158,10 +158,12 @@ class LITE_API Predictor {
cuda_use_multi_stream_ = multi_stream; cuda_use_multi_stream_ = multi_stream;
} }
// Whether the predictor dispatches CUDA kernels over multiple exec streams.
bool cuda_use_multi_stream() { return cuda_use_multi_stream_; }
// Setters/getters for the CUDA execution and IO streams. The streams are
// held as non-owning pointers (nullptr means "not set"); ownership stays
// with the caller — presumably a CudaStreamGuard. TODO(review): confirm
// lifetime contract with callers.
void set_cuda_exec_stream(cudaStream_t* stream) {
  cuda_exec_stream_ = stream;
}
void set_cuda_io_stream(cudaStream_t* stream) { cuda_io_stream_ = stream; }
cudaStream_t* cuda_exec_stream() { return cuda_exec_stream_; }
cudaStream_t* cuda_io_stream() { return cuda_io_stream_; }
#endif #endif
private: private:
...@@ -177,8 +179,8 @@ class LITE_API Predictor { ...@@ -177,8 +179,8 @@ class LITE_API Predictor {
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// CUDA stream state: multi-stream flag plus non-owning pointers to the
// IO and exec streams; nullptr until set via the corresponding setters.
bool cuda_use_multi_stream_{false};
cudaStream_t* cuda_io_stream_{nullptr};
cudaStream_t* cuda_exec_stream_{nullptr};
#endif #endif
}; };
......
...@@ -135,14 +135,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) { ...@@ -135,14 +135,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) {
} }
void CxxPaddleApiImpl::SyncCudaInputs() { void CxxPaddleApiImpl::SyncCudaInputs() {
TargetWrapperCuda::RecordEvent(cuda_input_event_, cuda_io_stream_->stream()); TargetWrapperCuda::RecordEvent(cuda_input_event_, *cuda_io_stream_->stream());
if (cuda_use_multi_stream_) { if (cuda_use_multi_stream_) {
for (int i = 0; i < lite::kMaxStream; ++i) { for (int i = 0; i < lite::kMaxStream; ++i) {
TargetWrapperCuda::StreamSync(cuda_exec_streams_[i].stream(), TargetWrapperCuda::StreamSync(*cuda_exec_streams_[i].stream(),
cuda_input_event_); cuda_input_event_);
} }
} else { } else {
TargetWrapperCuda::StreamSync(cuda_exec_stream_->stream(), TargetWrapperCuda::StreamSync(*cuda_exec_stream_->stream(),
cuda_input_event_); cuda_input_event_);
} }
} }
...@@ -151,14 +151,14 @@ void CxxPaddleApiImpl::SyncCudaOutputs() { ...@@ -151,14 +151,14 @@ void CxxPaddleApiImpl::SyncCudaOutputs() {
if (cuda_use_multi_stream_) { if (cuda_use_multi_stream_) {
for (size_t i = 0; i < cuda_output_events_.size(); ++i) { for (size_t i = 0; i < cuda_output_events_.size(); ++i) {
TargetWrapperCuda::RecordEvent(cuda_output_events_[i], TargetWrapperCuda::RecordEvent(cuda_output_events_[i],
cuda_exec_streams_[i].stream()); *cuda_exec_streams_[i].stream());
TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(), TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
cuda_output_events_[i]); cuda_output_events_[i]);
} }
} else { } else {
TargetWrapperCuda::RecordEvent(cuda_output_events_[0], TargetWrapperCuda::RecordEvent(cuda_output_events_[0],
cuda_exec_stream_->stream()); *cuda_exec_stream_->stream());
TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(), TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
cuda_output_events_[0]); cuda_output_events_[0]);
} }
} }
...@@ -168,7 +168,7 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) { ...@@ -168,7 +168,7 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
auto *x = raw_predictor_->GetInput(i); auto *x = raw_predictor_->GetInput(i);
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(x, cuda_io_stream_->stream())); new lite_api::Tensor(x, *cuda_io_stream_->stream()));
#else #else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif #endif
...@@ -179,7 +179,7 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput( ...@@ -179,7 +179,7 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
const auto *x = raw_predictor_->GetOutput(i); const auto *x = raw_predictor_->GetOutput(i);
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
return std::unique_ptr<lite_api::Tensor>( return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(x, cuda_io_stream_->stream())); new lite_api::Tensor(x, *cuda_io_stream_->stream()));
#else #else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x)); return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif #endif
......
...@@ -44,7 +44,7 @@ class CudaStreamGuard { ...@@ -44,7 +44,7 @@ class CudaStreamGuard {
lite::TargetWrapperCuda::DestroyStream(stream_); lite::TargetWrapperCuda::DestroyStream(stream_);
} }
} }
// Returns the address of the guarded stream (valid while the guard lives).
cudaStream_t* stream() { return &stream_; }
bool owned() { return owned_; } bool owned() { return owned_; }
private: private:
......
...@@ -243,9 +243,9 @@ RuntimeProgram::RuntimeProgram( ...@@ -243,9 +243,9 @@ RuntimeProgram::RuntimeProgram(
} }
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
void RuntimeProgram::UpdateCudaContext(cudaStream_t exec, cudaStream_t io) { void RuntimeProgram::UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
for (auto& inst : instructions_) { for (auto& inst : instructions_[kRootBlockIdx]) {
inst.UpdateCudaContext(exec, io); inst.UpdateCudaStream(exec, io);
} }
} }
#endif #endif
......
...@@ -140,10 +140,14 @@ struct Instruction { ...@@ -140,10 +140,14 @@ struct Instruction {
} }
} }
// Blocks until this instruction's CUDA context has finished its queued work.
void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); }
void UpdateCudaContext(cudaStream_t exec, cudaStream_t io) { void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
if (kernel_->target() == TargetType::kCUDA) { if (kernel_->target() == TargetType::kCUDA) {
kernel_->mutable_context()->As<CUDAContext>().SetExecStream(exec); if (exec) {
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(io); kernel_->mutable_context()->As<CUDAContext>().SetExecStream(*exec);
}
if (io) {
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(*io);
}
} }
} }
#endif #endif
...@@ -245,9 +249,9 @@ class LITE_API RuntimeProgram { ...@@ -245,9 +249,9 @@ class LITE_API RuntimeProgram {
void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc); void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc);
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// UpdateCudaContext will update the exec stream and io stream of all kernels // UpdateCudaStream will update the exec stream and io stream of all kernels
// in the program. // in the program.
void UpdateCudaContext(cudaStream_t exec, cudaStream_t io); void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io);
#endif #endif
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册