未验证 提交 c0e45193 编写于 作者: T TianXiaogang 提交者: GitHub

fix: fix omp thread bug (#2371)

* fix: fix omp thread bug

* fix: fix compile android level & test=develop

* fix: fix omp bug

* fix: fix omp bug
test=develop
上级 7e27b7bf
...@@ -34,6 +34,9 @@ endif() ...@@ -34,6 +34,9 @@ endif()
if(NOT DEFINED ANDROID_API_LEVEL) if(NOT DEFINED ANDROID_API_LEVEL)
set(ANDROID_API_LEVEL "23") set(ANDROID_API_LEVEL "23")
if(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
set(ANDROID_API_LEVEL "22")
endif()
endif() endif()
# then check input arm abi # then check input arm abi
......
...@@ -30,6 +30,9 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -30,6 +30,9 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif #endif
auto places = config.valid_places(); auto places = config.valid_places();
raw_predictor_.Build(config, places); raw_predictor_.Build(config, places);
mode_ = config.power_mode();
threads_ = config.threads();
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
...@@ -51,7 +54,12 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { ...@@ -51,7 +54,12 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames(); return raw_predictor_.GetOutputNames();
} }
void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); } void CxxPaddleApiImpl::Run() {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif
raw_predictor_.Run();
}
std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() { std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
...@@ -29,6 +29,9 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) { ...@@ -29,6 +29,9 @@ void LightPredictorImpl::Init(const lite_api::MobileConfig& config) {
config.param_buffer(), config.param_buffer(),
config.model_from_memory(), config.model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer)); lite_api::LiteModelType::kNaiveBuffer));
mode_ = config.power_mode();
threads_ = config.threads();
} }
std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) {
...@@ -42,7 +45,12 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput( ...@@ -42,7 +45,12 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput(
new lite_api::Tensor(raw_predictor_->GetOutput(i))); new lite_api::Tensor(raw_predictor_->GetOutput(i)));
} }
void LightPredictorImpl::Run() { raw_predictor_->Run(); } void LightPredictorImpl::Run() {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif
raw_predictor_->Run();
}
std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() { std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() {
LOG(FATAL) << "The Clone API is not supported in LigthPredictor"; LOG(FATAL) << "The Clone API is not supported in LigthPredictor";
......
...@@ -102,6 +102,10 @@ class LITE_API PaddlePredictor { ...@@ -102,6 +102,10 @@ class LITE_API PaddlePredictor {
bool record_info = false); bool record_info = false);
virtual ~PaddlePredictor() = default; virtual ~PaddlePredictor() = default;
protected:
int threads_{1};
lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND};
}; };
/// Base class for all the configs. /// Base class for all the configs.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册