Commit f9077aa4 authored by J jiweibo

update for subgraph. test=develop

Parent 47ebc8c4
......@@ -357,7 +357,7 @@ void Predictor::GenRuntimeProgram() {
program_generated_ = true;
#ifdef LITE_WITH_CUDA
if (!cuda_use_multi_stream_) {
// Push the configured exec/io streams (now held by pointer) down to every
// kernel of the generated program. Note: stale pre-change duplicate line
// (UpdateCudaContext) removed; the renamed pointer-based API is the one kept.
program_->UpdateCudaStream(cuda_exec_stream_, cuda_io_stream_);
}
#endif
}
......
......@@ -158,10 +158,12 @@ class LITE_API Predictor {
cuda_use_multi_stream_ = multi_stream;
}
bool cuda_use_multi_stream() { return cuda_use_multi_stream_; }
// Setters/getters for the CUDA execution and I/O streams. The streams are
// passed and stored by pointer (non-owning) so that a nullptr can later mean
// "leave the kernel's current stream unchanged"; duplicate pre-change
// value-based overloads removed.
void set_cuda_exec_stream(cudaStream_t* stream) { cuda_exec_stream_ = stream; }
void set_cuda_io_stream(cudaStream_t* stream) { cuda_io_stream_ = stream; }
cudaStream_t* cuda_exec_stream() { return cuda_exec_stream_; }
cudaStream_t* cuda_io_stream() { return cuda_io_stream_; }
#endif
private:
......@@ -177,8 +179,8 @@ class LITE_API Predictor {
#ifdef LITE_WITH_CUDA
bool cuda_use_multi_stream_{false};
// Non-owning pointers to externally managed CUDA streams; nullptr until the
// caller installs them. Stale pre-change value-typed declarations removed to
// avoid duplicate-member compile errors.
cudaStream_t* cuda_io_stream_{nullptr};
cudaStream_t* cuda_exec_stream_{nullptr};
#endif
};
......
......@@ -135,14 +135,14 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) {
}
// Record an event on the I/O stream once input transfers are enqueued, then
// make the exec stream(s) wait on that event so kernels do not start before
// the inputs are ready. CudaStreamGuard::stream() now returns cudaStream_t*,
// hence the dereferences; stale pre-change duplicate calls removed.
void CxxPaddleApiImpl::SyncCudaInputs() {
  TargetWrapperCuda::RecordEvent(cuda_input_event_, *cuda_io_stream_->stream());
  if (cuda_use_multi_stream_) {
    // Multi-stream mode: every exec stream must observe the input event.
    for (int i = 0; i < lite::kMaxStream; ++i) {
      TargetWrapperCuda::StreamSync(*cuda_exec_streams_[i].stream(),
                                    cuda_input_event_);
    }
  } else {
    TargetWrapperCuda::StreamSync(*cuda_exec_stream_->stream(),
                                  cuda_input_event_);
  }
}
......@@ -151,14 +151,14 @@ void CxxPaddleApiImpl::SyncCudaOutputs() {
// Each exec stream records its completion event; the I/O stream then waits on
// those events before output transfers run. Streams are dereferenced because
// CudaStreamGuard::stream() now returns cudaStream_t*; stale pre-change
// duplicate calls removed.
if (cuda_use_multi_stream_) {
  for (size_t i = 0; i < cuda_output_events_.size(); ++i) {
    TargetWrapperCuda::RecordEvent(cuda_output_events_[i],
                                   *cuda_exec_streams_[i].stream());
    TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
                                  cuda_output_events_[i]);
  }
} else {
  // Single-stream mode: only the first event is used.
  TargetWrapperCuda::RecordEvent(cuda_output_events_[0],
                                 *cuda_exec_stream_->stream());
  TargetWrapperCuda::StreamSync(*cuda_io_stream_->stream(),
                                cuda_output_events_[0]);
}
}
......@@ -168,7 +168,7 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
auto *x = raw_predictor_->GetInput(i);
#ifdef LITE_WITH_CUDA
// Wrap the raw tensor together with the (dereferenced) I/O stream so the
// returned Tensor can issue async copies; stale pre-change line removed.
return std::unique_ptr<lite_api::Tensor>(
    new lite_api::Tensor(x, *cuda_io_stream_->stream()));
#else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif
......@@ -179,7 +179,7 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
const auto *x = raw_predictor_->GetOutput(i);
#ifdef LITE_WITH_CUDA
// Wrap the raw tensor together with the (dereferenced) I/O stream so the
// returned Tensor can issue async copies; stale pre-change line removed.
return std::unique_ptr<lite_api::Tensor>(
    new lite_api::Tensor(x, *cuda_io_stream_->stream()));
#else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif
......
......@@ -44,7 +44,7 @@ class CudaStreamGuard {
lite::TargetWrapperCuda::DestroyStream(stream_);
}
}
// Expose the guarded stream by pointer so callers share the guard's stream
// storage instead of copying the handle; stale value-returning overload
// removed to avoid a duplicate-definition compile error.
cudaStream_t* stream() { return &stream_; }
bool owned() { return owned_; }
private:
......
......@@ -243,9 +243,9 @@ RuntimeProgram::RuntimeProgram(
}
#ifdef LITE_WITH_CUDA
void RuntimeProgram::UpdateCudaContext(cudaStream_t exec, cudaStream_t io) {
for (auto& inst : instructions_) {
inst.UpdateCudaContext(exec, io);
void RuntimeProgram::UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
for (auto& inst : instructions_[kRootBlockIdx]) {
inst.UpdateCudaStream(exec, io);
}
}
#endif
......
......@@ -140,10 +140,14 @@ struct Instruction {
}
}
void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); }
void UpdateCudaContext(cudaStream_t exec, cudaStream_t io) {
void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io) {
if (kernel_->target() == TargetType::kCUDA) {
kernel_->mutable_context()->As<CUDAContext>().SetExecStream(exec);
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(io);
if (exec) {
kernel_->mutable_context()->As<CUDAContext>().SetExecStream(*exec);
}
if (io) {
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(*io);
}
}
}
#endif
......@@ -245,9 +249,9 @@ class LITE_API RuntimeProgram {
void SaveToProgram(std::shared_ptr<cpp::ProgramDesc> program_desc);
#ifdef LITE_WITH_CUDA
// UpdateCudaStream will update the exec stream and io stream of all kernels
// in the program. Null pointers leave the corresponding stream unchanged.
// (Stale pre-change comment and UpdateCudaContext declaration removed.)
void UpdateCudaStream(cudaStream_t* exec, cudaStream_t* io);
#endif
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment