提交 f66e4a5e 编写于 作者: D DannyIsFunny

test=develop

上级 97c10e3e
......@@ -34,38 +34,43 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config;
auto places = config.valid_places();
std::vector<std::string> passes{};
if (!status_is_cloned_) {
auto places = config.valid_places();
std::vector<std::string> passes{};
#ifdef LITE_WITH_CUDA
// if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step.
for (auto &p : places) {
if (p.target == TARGET(kCUDA)) {
Env<TARGET(kCUDA)>::Init();
if (config_.multi_stream()) {
passes = {"multi_stream_analysis_pass"};
VLOG(3) << "add pass: " << passes[0];
// if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step.
for (auto &p : places) {
if (p.target == TARGET(kCUDA)) {
Env<TARGET(kCUDA)>::Init();
if (config_.multi_stream()) {
passes = {"multi_stream_analysis_pass"};
VLOG(3) << "add pass: " << passes[0];
}
break;
}
break;
}
}
#endif
#ifdef LITE_WITH_MLU
Env<TARGET(kMLU)>::Init();
lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
config.mlu_first_conv_std(),
config.mlu_input_layout());
Env<TARGET(kMLU)>::Init();
lite::DeviceInfo::Global().SetMLURunMode(config.mlu_core_version(),
config.mlu_core_number(),
config.mlu_use_first_conv(),
config.mlu_first_conv_mean(),
config.mlu_first_conv_std(),
config.mlu_input_layout());
#endif // LITE_WITH_MLU
auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS");
VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
if (places[0].target == TARGET(kOpenCL) &&
use_layout_preprocess_pass != std::string::npos) {
passes = {"type_layout_cast_preprocess_pass"};
VLOG(1) << "add pass:" << passes[0];
auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS");
VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
if (places[0].target == TARGET(kOpenCL) &&
use_layout_preprocess_pass != std::string::npos) {
passes = {"type_layout_cast_preprocess_pass"};
VLOG(1) << "add pass:" << passes[0];
}
raw_predictor_->Build(config, places, passes);
} else {
CHECK(raw_predictor_) << "The Predictor can not be nullptr in Clone mode.";
}
mode_ = config.power_mode();
threads_ = config.threads();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册