From c9c1d31979e0ba087e81462ee085aebd330e2fae Mon Sep 17 00:00:00 2001 From: yangfei Date: Mon, 5 Nov 2018 21:21:21 +0800 Subject: [PATCH] fix cl path problem --- src/framework/cl/cl_engine.h | 4 ++++ src/framework/cl/cl_scope.h | 3 ++- src/io/api_paddle_mobile.cc | 4 +++- src/io/paddle_inference_api.h | 1 + src/io/paddle_mobile.cpp | 7 +++++++ src/io/paddle_mobile.h | 8 ++++++++ test/net/test_mobilenet_GPU.cpp | 2 +- 7 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/framework/cl/cl_engine.h b/src/framework/cl/cl_engine.h index f9f373b2a7..76d08513aa 100644 --- a/src/framework/cl/cl_engine.h +++ b/src/framework/cl/cl_engine.h @@ -114,6 +114,9 @@ class CLEngine { cl_device_id DeviceID(int index = 0) { return devices_[index]; } + std::string GetCLPath() { return cl_path_; } + void setClPath(std::string cl_path) { cl_path_ = cl_path; } + private: CLEngine() { initialized_ = false; } @@ -129,6 +132,7 @@ class CLEngine { cl_int status_; + std::string cl_path_; std::unique_ptr<_cl_program, CLProgramDeleter> program_; // bool SetClContext(); diff --git a/src/framework/cl/cl_scope.h b/src/framework/cl/cl_scope.h index 0965b133e6..c7c06ca75f 100644 --- a/src/framework/cl/cl_scope.h +++ b/src/framework/cl/cl_scope.h @@ -58,7 +58,8 @@ class CLScope { } auto program = CLEngine::Instance()->CreateProgramWith( - context_.get(), "./cl_kernel/" + file_name); + context_.get(), + CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name); DLOG << " --- begin build program -> " << file_name << " --- "; CLEngine::Instance()->BuildProgram(program.get()); diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc index 67f255315f..144cf127a4 100644 --- a/src/io/api_paddle_mobile.cc +++ b/src/io/api_paddle_mobile.cc @@ -29,7 +29,9 @@ PaddleMobilePredictor::PaddleMobilePredictor( template bool PaddleMobilePredictor::Init(const PaddleMobileConfig &config) { paddle_mobile_.reset(new PaddleMobile()); - +#ifdef PADDLE_MOBILE_CL + paddle_mobile_->SetCLPath(config.cl_path); +#endif if (config.memory_pack.from_memory) { DLOG << "load from memory!"; paddle_mobile_->LoadCombinedMemory(config.memory_pack.model_size, diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h index d37895d3aa..3c9ffa00c7 100644 --- a/src/io/paddle_inference_api.h +++ b/src/io/paddle_inference_api.h @@ -132,6 +132,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config { int thread_num = 1; std::string prog_file; std::string param_file; + std::string cl_path; struct PaddleModelMemoryPack memory_pack; }; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 3cd7c38b2b..921b72520f 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -158,6 +158,13 @@ void PaddleMobile::Predict_To(int end) { } #endif +#ifdef PADDLE_MOBILE_CL +template +void PaddleMobile::SetCLPath(std::string path) { + framework::CLEngine::Instance()->setClPath(path); +} +#endif + template class PaddleMobile; template class PaddleMobile; template class PaddleMobile; diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 0e86fa988f..1e8f81c51e 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -26,6 +26,9 @@ limitations under the License. */ #include "framework/load_ops.h" #include "framework/loader.h" #include "framework/tensor.h" +#ifdef PADDLE_MOBILE_CL +#include "framework/cl/cl_engine.h" +#endif namespace paddle_mobile { @@ -68,6 +71,11 @@ class PaddleMobile { void Predict_To(int end); #endif +#ifdef PADDLE_MOBILE_CL + public: + void SetCLPath(std::string cl_path); +#endif + private: std::shared_ptr> loader_; std::shared_ptr> executor_; diff --git a/test/net/test_mobilenet_GPU.cpp b/test/net/test_mobilenet_GPU.cpp index a5276d6e52..07582e10dd 100644 --- a/test/net/test_mobilenet_GPU.cpp +++ b/test/net/test_mobilenet_GPU.cpp @@ -22,7 +22,7 @@ int main() { auto time1 = paddle_mobile::time(); // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", // std::string(g_mobilenet_detect) + "/params", true); - + paddle_mobile.SetCLPath("."); auto isok = paddle_mobile.Load(std::string(g_mobilenet), true); if (isok) { auto time2 = paddle_mobile::time(); -- GitLab