diff --git a/src/framework/cl/cl_engine.h b/src/framework/cl/cl_engine.h index f9f373b2a74087960b03c55ec922f95f187cfbc4..76d08513aa4301b9aa22b159a70a17b7b0619b92 100644 --- a/src/framework/cl/cl_engine.h +++ b/src/framework/cl/cl_engine.h @@ -114,6 +114,9 @@ class CLEngine { cl_device_id DeviceID(int index = 0) { return devices_[index]; } + std::string GetCLPath() { return cl_path_; } + void setClPath(std::string cl_path) { cl_path_ = cl_path; } + private: CLEngine() { initialized_ = false; } @@ -129,6 +132,7 @@ class CLEngine { cl_int status_; + std::string cl_path_; std::unique_ptr<_cl_program, CLProgramDeleter> program_; // bool SetClContext(); diff --git a/src/framework/cl/cl_scope.h b/src/framework/cl/cl_scope.h index 0965b133e6d8270b7cd6e28c8ed9a33739b2e2a8..c7c06ca75f47cd65d2350dfa6930068aca73ced0 100644 --- a/src/framework/cl/cl_scope.h +++ b/src/framework/cl/cl_scope.h @@ -58,7 +58,8 @@ class CLScope { } auto program = CLEngine::Instance()->CreateProgramWith( - context_.get(), "./cl_kernel/" + file_name); + context_.get(), + CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name); DLOG << " --- begin build program -> " << file_name << " --- "; CLEngine::Instance()->BuildProgram(program.get()); diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc index 67f255315fa71acbf24f5071735020c0a435ce64..144cf127a44c78279ca1d95815646a4f01fed6bd 100644 --- a/src/io/api_paddle_mobile.cc +++ b/src/io/api_paddle_mobile.cc @@ -29,7 +29,9 @@ PaddleMobilePredictor::PaddleMobilePredictor( template bool PaddleMobilePredictor::Init(const PaddleMobileConfig &config) { paddle_mobile_.reset(new PaddleMobile()); - +#ifdef PADDLE_MOBILE_CL + paddle_mobile_->SetCLPath(config.cl_path); +#endif if (config.memory_pack.from_memory) { DLOG << "load from memory!"; paddle_mobile_->LoadCombinedMemory(config.memory_pack.model_size, diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h index d37895d3aaa108edb1a8956ccbcb91cbe4b97725..3c9ffa00c7e749d1c9d77562b2db0b42ee605164 100644 --- a/src/io/paddle_inference_api.h +++ b/src/io/paddle_inference_api.h @@ -132,6 +132,7 @@ struct PaddleMobileConfig : public PaddlePredictor::Config { int thread_num = 1; std::string prog_file; std::string param_file; + std::string cl_path; struct PaddleModelMemoryPack memory_pack; }; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 3cd7c38b2b102659739aefc66b4b25f61cc48bcf..921b72520f1905fcdc7b2a0d15ee4ec5d844cda7 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -158,6 +158,13 @@ void PaddleMobile::Predict_To(int end) { } #endif +#ifdef PADDLE_MOBILE_CL +template +void PaddleMobile::SetCLPath(std::string path) { + framework::CLEngine::Instance()->setClPath(path); +} +#endif + template class PaddleMobile; template class PaddleMobile; template class PaddleMobile; diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 0e86fa988fe8a07131d3ea19fe7c606c27d70c2c..1e8f81c51e02ea6bdbdea8694aa62c9c30e6e6a8 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -26,6 +26,9 @@ limitations under the License. */ #include "framework/load_ops.h" #include "framework/loader.h" #include "framework/tensor.h" +#ifdef PADDLE_MOBILE_CL +#include "framework/cl/cl_engine.h" +#endif namespace paddle_mobile { @@ -68,6 +71,11 @@ class PaddleMobile { void Predict_To(int end); #endif +#ifdef PADDLE_MOBILE_CL + public: + void SetCLPath(std::string cl_path); +#endif + private: std::shared_ptr> loader_; std::shared_ptr> executor_; diff --git a/test/net/test_mobilenet_GPU.cpp b/test/net/test_mobilenet_GPU.cpp index a5276d6e521855ad81e6b9e2edb58c271ae713d9..07582e10dd5db8985f87bae215b8cf1808431565 100644 --- a/test/net/test_mobilenet_GPU.cpp +++ b/test/net/test_mobilenet_GPU.cpp @@ -22,7 +22,7 @@ int main() { auto time1 = paddle_mobile::time(); // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", // std::string(g_mobilenet_detect) + "/params", true); - + paddle_mobile.SetCLPath("."); auto isok = paddle_mobile.Load(std::string(g_mobilenet), true); if (isok) { auto time2 = paddle_mobile::time();