提交 dcde5c7c 编写于 作者: J jiweibo

update

上级 37f606d2
......@@ -354,8 +354,8 @@ void Predictor::GenRuntimeProgram() {
CHECK_EQ(exec_scope_, program_->exec_scope());
program_generated_ = true;
#ifdef LITE_WITH_CUDA
if (!multi_stream_) {
program_->UpdateContext(exec_stream_, io_stream_);
if (!cuda_use_multi_stream_) {
program_->UpdateContext(cuda_exec_stream_, cuda_io_stream_);
}
#endif
}
......
......@@ -29,6 +29,7 @@
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/stream_wrapper.h"
#endif
namespace paddle {
......@@ -154,22 +155,15 @@ class LITE_API Predictor {
bool record_info = false);
void SaveOpKernelInfo(const std::string& model_dir);
// #ifdef LITE_WITH_TRAIN
// void Run(const std::vector<framework::Tensor>& tensors) {
// FeedVars(tensors);
// program_->Run();
// }
// void FeedVars(const std::vector<framework::Tensor>& tensors);
// #endif
#ifdef LITE_WITH_CUDA
void set_multi_stream(bool multi_stream) { multi_stream_ = multi_stream; }
bool multi_stream() { return multi_stream_; }
void set_exec_stream(cudaStream_t* stream) { exec_stream_ = stream; }
void set_io_stream(cudaStream_t* stream) { io_stream_ = stream; }
const cudaStream_t& exec_stream() { return *exec_stream_; }
const cudaStream_t& io_stream() { return *io_stream_; }
void set_cuda_use_multi_stream(bool multi_stream) {
cuda_use_multi_stream_ = multi_stream;
}
bool cuda_use_multi_stream() { return cuda_use_multi_stream_; }
void set_cuda_exec_stream(cudaStream_t stream) { cuda_exec_stream_ = stream; }
void set_cuda_io_stream(cudaStream_t stream) { cuda_io_stream_ = stream; }
cudaStream_t cuda_exec_stream() { return cuda_exec_stream_; }
cudaStream_t cuda_io_stream() { return cuda_io_stream_; }
#endif
private:
......@@ -182,10 +176,11 @@ class LITE_API Predictor {
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
std::vector<Place> valid_places_;
#ifdef LITE_WITH_CUDA
bool multi_stream_{false};
cudaStream_t* io_stream_{nullptr};
cudaStream_t* exec_stream_{nullptr};
bool cuda_use_multi_stream_{false};
cudaStream_t cuda_io_stream_;
cudaStream_t cuda_exec_stream_;
#endif
};
......@@ -247,8 +242,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
void InitCudaEnv(std::vector<std::string>* passes);
// Due to the asynchronous nature of cuda kernel execution, synchronization is
// required before setting input and getting output.
void SyncInputs();
void SyncOutputs();
void SyncCudaInputs();
void SyncCudaOutputs();
#endif
private:
......@@ -256,76 +251,17 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
lite_api::CxxConfig config_;
std::mutex mutex_;
bool status_is_cloned_;
#ifdef LITE_WITH_CUDA
bool multi_stream_{false};
std::shared_ptr<cudaStream_t> io_stream_;
std::shared_ptr<cudaStream_t> exec_stream_;
cudaEvent_t input_event_;
std::vector<cudaEvent_t> output_events_;
// only for multi exec stream mode.
std::vector<cudaStream_t*> exec_streams_;
#endif
};
/*
* An executor for training.
*
* Usage:
*
* CXXTrainer trainer(...);
* trainer.RunStartupProgram(...);
* auto exe = BuildMainProgramExecutor(...);
*
* for (auto& epoch : epoches) {
* auto* tensor0 = exe.GetInput(...);
* // fill data for tensor0
* exe.Run();
* }
#ifdef LITE_WITH_X86
class LITE_API CXXTrainer {
public:
CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
const std::vector<Place>& valid_places)
: scope_(root_scope),
valid_places_(valid_places),
main_program_executor_(Predictor(scope_)) {}
// Build the RuntimeProgram cache for the main program. The cache will run
// multiple times for the epoches.
// NOTE Just support to execute the 0-th block currently.
Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
int block_id = 0) {
main_program_executor_.Build(desc, valid_places_);
return main_program_executor_;
}
#ifdef LITE_WITH_TRAIN
Predictor& BuildMainProgramExecutor(framework::ProgramDesc& desc) { // NOLINT
return BuildMainProgramExecutor(*desc.Proto());
}
void RunStartupProgram(framework::ProgramDesc& desc) { // NOLINT
RunStartupProgram(*desc.Proto());
}
#ifdef LITE_WITH_CUDA
bool cuda_use_multi_stream_{false};
std::unique_ptr<lite::StreamWrapper> cuda_io_stream_;
std::unique_ptr<lite::StreamWrapper> cuda_exec_stream_;
cudaEvent_t cuda_input_event_;
std::vector<cudaEvent_t> cuda_output_events_;
// only used for multi exec stream mode.
std::vector<lite::StreamWrapper> cuda_exec_streams_;
#endif
// Run the startup program. It just executes once, no cache needed.
void RunStartupProgram(const framework::proto::ProgramDesc& desc,
int block_id = 0) {
Predictor exe(scope_);
exe.Build(desc, valid_places_);
exe.Run();
}
private:
std::shared_ptr<lite::Scope> scope_;
std::vector<Place> valid_places_;
// The training program.
Predictor main_program_executor_;
};
#endif
*/
} // namespace lite
} // namespace paddle
......@@ -36,6 +36,7 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config;
config_.check_valid();
auto places = config.valid_places();
std::vector<std::string> passes = config.get_passes_internal();
#ifdef LITE_WITH_CUDA
......@@ -94,65 +95,69 @@ void CxxPaddleApiImpl::InitCudaEnv(std::vector<std::string> *passes) {
Env<TARGET(kCUDA)>::Init();
// init two streams for each predictor.
if (config_.exec_stream()) {
exec_stream_ = config_.exec_stream();
if (config_.cuda_exec_stream()) {
cuda_exec_stream_.reset(
new lite::StreamWrapper(*config_.cuda_exec_stream()));
} else {
exec_stream_ = std::make_shared<cudaStream_t>();
TargetWrapperCuda::CreateStream(exec_stream_.get());
cuda_exec_stream_.reset(new lite::StreamWrapper());
}
if (config_.io_stream()) {
io_stream_ = config_.io_stream();
if (config_.cuda_io_stream()) {
cuda_io_stream_.reset(new lite::StreamWrapper(*config_.cuda_io_stream()));
} else {
io_stream_ = std::make_shared<cudaStream_t>();
TargetWrapperCuda::CreateStream(io_stream_.get());
cuda_io_stream_.reset(new lite::StreamWrapper());
}
raw_predictor_->set_exec_stream(exec_stream_.get());
raw_predictor_->set_io_stream(io_stream_.get());
raw_predictor_->set_cuda_exec_stream(cuda_exec_stream_->stream());
raw_predictor_->set_cuda_io_stream(cuda_io_stream_->stream());
// init sync events.
if (config_.multi_stream()) {
multi_stream_ = true;
raw_predictor_->set_multi_stream(multi_stream_);
if (config_.cuda_use_multi_stream()) {
cuda_use_multi_stream_ = true;
raw_predictor_->set_cuda_use_multi_stream(cuda_use_multi_stream_);
passes->push_back("multi_stream_analysis_pass");
VLOG(3) << "add pass: " << (*passes)[0];
Env<TargetType::kCUDA>::Devs &devs = Env<TargetType::kCUDA>::Global();
int dev_id = TargetWrapperCuda::GetCurDevice();
for (size_t i = 0; i < lite::kMaxStream; ++i) {
exec_streams_.push_back(
const_cast<cudaStream_t *>(&devs[dev_id].exec_streams()[i]));
cuda_exec_streams_.emplace_back(devs[dev_id].exec_streams()[i]);
cudaEvent_t out_event;
TargetWrapperCuda::CreateEventWithFlags(&out_event);
output_events_.push_back(out_event);
cuda_output_events_.push_back(out_event);
}
} else {
cudaEvent_t out_event;
TargetWrapperCuda::CreateEventWithFlags(&out_event);
output_events_.push_back(out_event);
cuda_output_events_.push_back(out_event);
}
TargetWrapperCuda::CreateEventWithFlags(&input_event_);
TargetWrapperCuda::CreateEventWithFlags(&cuda_input_event_);
}
void CxxPaddleApiImpl::SyncInputs() {
TargetWrapperCuda::RecordEvent(input_event_, *io_stream_);
if (multi_stream_) {
void CxxPaddleApiImpl::SyncCudaInputs() {
TargetWrapperCuda::RecordEvent(cuda_input_event_, cuda_io_stream_->stream());
if (cuda_use_multi_stream_) {
for (int i = 0; i < lite::kMaxStream; ++i) {
TargetWrapperCuda::StreamSync(*exec_streams_[i], input_event_);
TargetWrapperCuda::StreamSync(cuda_exec_streams_[i].stream(),
cuda_input_event_);
}
} else {
TargetWrapperCuda::StreamSync(*exec_stream_, input_event_);
TargetWrapperCuda::StreamSync(cuda_exec_stream_->stream(),
cuda_input_event_);
}
}
void CxxPaddleApiImpl::SyncOutputs() {
if (multi_stream_) {
for (size_t i = 0; i < output_events_.size(); ++i) {
TargetWrapperCuda::RecordEvent(output_events_[i], *exec_streams_[i]);
TargetWrapperCuda::StreamSync(*io_stream_, output_events_[i]);
void CxxPaddleApiImpl::SyncCudaOutputs() {
if (cuda_use_multi_stream_) {
for (size_t i = 0; i < cuda_output_events_.size(); ++i) {
TargetWrapperCuda::RecordEvent(cuda_output_events_[i],
cuda_exec_streams_[i].stream());
TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(),
cuda_output_events_[i]);
}
} else {
TargetWrapperCuda::RecordEvent(output_events_[0], *exec_stream_);
TargetWrapperCuda::StreamSync(*io_stream_, output_events_[0]);
TargetWrapperCuda::RecordEvent(cuda_output_events_[0],
cuda_exec_stream_->stream());
TargetWrapperCuda::StreamSync(cuda_io_stream_->stream(),
cuda_output_events_[0]);
}
}
#endif
......@@ -161,7 +166,7 @@ std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
auto *x = raw_predictor_->GetInput(i);
#ifdef LITE_WITH_CUDA
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(x, io_stream_.get()));
new lite_api::Tensor(x, cuda_io_stream_->stream()));
#else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif
......@@ -172,7 +177,7 @@ std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
const auto *x = raw_predictor_->GetOutput(i);
#ifdef LITE_WITH_CUDA
return std::unique_ptr<lite_api::Tensor>(
new lite_api::Tensor(x, io_stream_.get()));
new lite_api::Tensor(x, cuda_io_stream_->stream()));
#else
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
#endif
......@@ -195,13 +200,13 @@ void CxxPaddleApiImpl::Run() {
lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif
#ifdef LITE_WITH_CUDA
SyncInputs();
SyncCudaInputs();
#endif
raw_predictor_->Run();
#ifdef LITE_WITH_CUDA
SyncOutputs();
SyncCudaOutputs();
#endif
}
......@@ -250,9 +255,9 @@ void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir,
CxxPaddleApiImpl::~CxxPaddleApiImpl() {
#ifdef LITE_WITH_CUDA
TargetWrapperCuda::DestroyEvent(input_event_);
for (size_t i = 0; i < output_events_.size(); ++i) {
TargetWrapperCuda::DestroyEvent(output_events_[i]);
TargetWrapperCuda::DestroyEvent(cuda_input_event_);
for (size_t i = 0; i < cuda_output_events_.size(); ++i) {
TargetWrapperCuda::DestroyEvent(cuda_output_events_[i]);
}
#endif
}
......
......@@ -41,10 +41,10 @@ Tensor::Tensor(void *raw) : raw_tensor_(raw) {}
Tensor::Tensor(const void *raw) { raw_tensor_ = const_cast<void *>(raw); }
#ifdef LITE_WITH_CUDA
Tensor::Tensor(void *raw, cudaStream_t *stream)
: raw_tensor_(raw), io_stream_(stream) {}
Tensor::Tensor(void *raw, cudaStream_t stream)
: raw_tensor_(raw), cuda_io_stream_(stream) {}
Tensor::Tensor(const void *raw, cudaStream_t *stream) : io_stream_(stream) {
Tensor::Tensor(const void *raw, cudaStream_t stream) : cuda_io_stream_(stream) {
raw_tensor_ = const_cast<void *>(raw);
}
#endif
......@@ -112,8 +112,11 @@ void Tensor::CopyFromCpu(const T *src_data) {
data, src_data, num * sizeof(T), lite::IoDirection::HtoH);
} else if (type == TargetType::kCUDA) {
#ifdef LITE_WITH_CUDA
lite::TargetWrapperCuda::MemcpyAsync(
data, src_data, num * sizeof(T), lite::IoDirection::HtoD, *io_stream_);
lite::TargetWrapperCuda::MemcpyAsync(data,
src_data,
num * sizeof(T),
lite::IoDirection::HtoD,
cuda_io_stream_);
#else
LOG(FATAL) << "Please compile the lib with CUDA.";
#endif
......@@ -139,9 +142,12 @@ void Tensor::CopyToCpu(T *data) const {
data, src_data, num * sizeof(T), lite::IoDirection::HtoH);
} else if (type == TargetType::kCUDA) {
#ifdef LITE_WITH_CUDA
lite::TargetWrapperCuda::MemcpyAsync(
data, src_data, num * sizeof(T), lite::IoDirection::DtoH, *io_stream_);
lite::TargetWrapperCuda::StreamSync(*io_stream_);
lite::TargetWrapperCuda::MemcpyAsync(data,
src_data,
num * sizeof(T),
lite::IoDirection::DtoH,
cuda_io_stream_);
lite::TargetWrapperCuda::StreamSync(cuda_io_stream_);
#else
LOG(FATAL) << "Please compile the lib with CUDA.";
#endif
......
......@@ -23,6 +23,7 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle_place.h" // NOLINT
#ifdef LITE_WITH_CUDA
......@@ -67,14 +68,14 @@ struct LITE_API Tensor {
void SetLoD(const lod_t& lod);
#ifdef LITE_WITH_CUDA
explicit Tensor(void* raw, cudaStream_t* stream);
explicit Tensor(const void* raw, cudaStream_t* stream);
explicit Tensor(void* raw, cudaStream_t stream);
explicit Tensor(const void* raw, cudaStream_t stream);
#endif
private:
void* raw_tensor_;
#ifdef LITE_WITH_CUDA
cudaStream_t* io_stream_{nullptr};
cudaStream_t cuda_io_stream_;
#endif
};
......@@ -166,11 +167,13 @@ class LITE_API CxxConfig : public ConfigBase {
#ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1;
#endif
#ifdef LITE_WITH_CUDA
bool multi_stream_{false};
std::shared_ptr<cudaStream_t> exec_stream_;
std::shared_ptr<cudaStream_t> io_stream_;
bool cuda_use_multi_stream_{false};
cudaStream_t* cuda_exec_stream_{nullptr};
cudaStream_t* cuda_io_stream_{nullptr};
#endif
#ifdef LITE_WITH_MLU
lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLUCoreVersion::MLU_270};
int mlu_core_number_{1};
......@@ -205,6 +208,19 @@ class LITE_API CxxConfig : public ConfigBase {
std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; }
bool model_from_memory() const { return model_from_memory_; }
bool check_valid() const {
#ifdef LITE_WITH_CUDA
if (cuda_use_multi_stream_ && (cuda_exec_stream_ || cuda_io_stream_)) {
LOG(FATAL) << "Can not set cuda_use_multi_stream and cuda_exec/io_stream "
"simultaneously. cuda_use_multi_stream is only valid in "
"single thread, it is designed to started multiple streams "
"within a model. cuda_exec/io_stream is to set an exec/io "
"stream for each thread, that is, each thread has its own "
"exec/io stream";
}
#endif
return true;
}
#ifdef LITE_WITH_X86
void set_x86_math_library_num_threads(int threads) {
......@@ -214,17 +230,23 @@ class LITE_API CxxConfig : public ConfigBase {
return x86_math_library_math_threads_;
}
#endif
#ifdef LITE_WITH_CUDA
void set_multi_stream(bool multi_stream) { multi_stream_ = multi_stream; }
bool multi_stream() const { return multi_stream_; }
void set_exec_stream(std::shared_ptr<cudaStream_t> exec_stream) {
exec_stream_ = exec_stream;
void set_cuda_use_multi_stream(bool use_multi_stream) {
cuda_use_multi_stream_ = use_multi_stream;
}
void set_io_stream(std::shared_ptr<cudaStream_t> io_stream) {
io_stream_ = io_stream;
bool cuda_use_multi_stream() const { return cuda_use_multi_stream_; }
void set_cuda_stream(cudaStream_t* exec_stream = nullptr,
cudaStream_t* io_stream = nullptr) {
if (exec_stream) {
cuda_exec_stream_ = exec_stream;
}
if (io_stream) {
cuda_io_stream_ = io_stream;
}
}
std::shared_ptr<cudaStream_t> exec_stream() { return exec_stream_; }
std::shared_ptr<cudaStream_t> io_stream() { return io_stream_; }
cudaStream_t* cuda_exec_stream() { return cuda_exec_stream_; }
cudaStream_t* cuda_io_stream() { return cuda_io_stream_; }
#endif
#ifdef LITE_WITH_MLU
......
......@@ -95,45 +95,45 @@ TEST(Resnet50, config_exec_stream) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)}});
std::shared_ptr<cudaStream_t> exec_stream = std::make_shared<cudaStream_t>();
lite::TargetWrapperCuda::CreateStream(exec_stream.get());
config.set_exec_stream(exec_stream);
cudaStream_t stream;
lite::TargetWrapperCuda::CreateStream(&stream);
config.set_cuda_stream(&stream);
RunModel(config);
}
TEST(Resnet50, config_io_stream) {
TEST(Resnet50, config_all_stream) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)}});
std::shared_ptr<cudaStream_t> io_stream = std::make_shared<cudaStream_t>();
lite::TargetWrapperCuda::CreateStream(io_stream.get());
config.set_io_stream(io_stream);
cudaStream_t exec_stream;
lite::TargetWrapperCuda::CreateStream(&exec_stream);
cudaStream_t io_stream;
lite::TargetWrapperCuda::CreateStream(&io_stream);
config.set_cuda_stream(&exec_stream, &io_stream);
RunModel(config);
}
TEST(Resnet50, config_all_stream) {
TEST(Resnet50, config_multi_exec_stream) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)}});
std::shared_ptr<cudaStream_t> exec_stream = std::make_shared<cudaStream_t>();
lite::TargetWrapperCuda::CreateStream(exec_stream.get());
config.set_exec_stream(exec_stream);
std::shared_ptr<cudaStream_t> io_stream = std::make_shared<cudaStream_t>();
lite::TargetWrapperCuda::CreateStream(io_stream.get());
config.set_io_stream(io_stream);
config.set_cuda_use_multi_stream(true);
RunModel(config);
}
TEST(Resnet50, config_multi_exec_stream) {
TEST(Resnet50, config_error) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)}});
config.set_multi_stream(true);
config.set_cuda_use_multi_stream(true);
cudaStream_t exec_stream;
lite::TargetWrapperCuda::CreateStream(&exec_stream);
config.set_cuda_stream(&exec_stream);
RunModel(config);
ASSERT_DEATH(RunModel(config), "");
}
} // namespace lite
......
......@@ -9,5 +9,6 @@ nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_deps})
nv_library(nvtx_wrapper SRCS nvtx_wrapper DEPS ${cuda_deps})
lite_cc_library(cuda_context SRCS context.cc DEPS device_info)
lite_cc_library(stream_wrapper SRCS stream_wrapper.cc DEPS target_wrapper_cuda ${cuda_deps})
add_subdirectory(math)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/stream_wrapper.h"
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/target_wrapper.h"
namespace paddle {
namespace lite {
class StreamWrapper {
public:
explicit StreamWrapper(cudaStream_t stream)
: stream_(stream), owner_(false) {}
StreamWrapper() : owner_(true) {
lite::TargetWrapperCuda::CreateStream(&stream_);
}
~StreamWrapper() {
if (owner_) {
lite::TargetWrapperCuda::DestroyStream(stream_);
}
}
cudaStream_t stream() { return stream_; }
bool owner() { return owner_; }
private:
cudaStream_t stream_;
bool owner_;
};
} // namespace lite
} // namespace paddle
......@@ -144,7 +144,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
}
#ifdef LITE_WITH_CUDA
void RuntimeProgram::UpdateContext(cudaStream_t* exec, cudaStream_t* io) {
void RuntimeProgram::UpdateContext(cudaStream_t exec, cudaStream_t io) {
for (auto& inst : instructions_) {
inst.UpdateContext(exec, io);
}
......
......@@ -128,10 +128,10 @@ struct Instruction {
}
}
void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); }
void UpdateContext(cudaStream_t* exec, cudaStream_t* io) {
void UpdateContext(cudaStream_t exec, cudaStream_t io) {
if (kernel_->target() == TargetType::kCUDA) {
kernel_->mutable_context()->As<CUDAContext>().SetExecStream(*exec);
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(*io);
kernel_->mutable_context()->As<CUDAContext>().SetExecStream(exec);
kernel_->mutable_context()->As<CUDAContext>().SetIoStream(io);
}
}
#endif
......@@ -224,7 +224,7 @@ class LITE_API RuntimeProgram {
#ifdef LITE_WITH_CUDA
// UpdateContext will update the exec stream and io stream of all kernels in
// the program.
void UpdateContext(cudaStream_t* exec, cudaStream_t* io);
void UpdateContext(cudaStream_t exec, cudaStream_t io);
#endif
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册