diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 7543a9f7a42e129690fe6882b74f9d1d2f3b5368..8bc3c15c6d761234f5f45a645fde81ff97822f03 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -645,15 +645,22 @@ std::vector::Ptype> Executor::Predict( } #ifdef PADDLE_MOBILE_FPGA + template -void Executor::FeedData(const framework::Tensor &t) { - framework::Variable *g_feed_value = program_.scope->Var("feed"); +void Executor::InjectVariable(const framework::Tensor &t, + string var_name) { + framework::Variable *g_feed_value = program_.scope->Var(var_name); framework::Tensor *feed_tensor = g_feed_value->GetMutable(); feed_tensor->Resize(t.dims()); feed_tensor->ShareDataWith(t); }; +template +void Executor::FeedData(const framework::Tensor &t) { + InjectVariable(t, "feed"); +}; + template std::shared_ptr Executor::FetchResult() { std::shared_ptr to_predict_block = diff --git a/src/io/executor.h b/src/io/executor.h index f1f3f9da7ccaaba1832e9b9e17b408118926b23f..bec9f45444a7502c1b6a119f80f55220765efe50 100644 --- a/src/io/executor.h +++ b/src/io/executor.h @@ -30,6 +30,7 @@ limitations under the License. */ #include #include "common/dep_core.h" #endif +using std::string; namespace paddle_mobile { @@ -94,7 +95,9 @@ class Executor { framework::LoDTensor *tensor) const; #ifdef PADDLE_MOBILE_FPGA + public: + void InjectVariable(const framework::Tensor &t, string var_name); void FeedData(const framework::Tensor &t); std::shared_ptr FetchResult(); void Predict_From_To(int start = 0, int end = -1); diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 9056bac20696abe959c20036d2a0ff7c9a218f35..9710a0ec452db3381e051db95d7da81b48f5f154 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -125,9 +125,16 @@ PaddleMobile::~PaddleMobile() { } #ifdef PADDLE_MOBILE_FPGA + +template +void PaddleMobile::InjectVariable(const framework::Tensor &t, + string var_name) { + executor_->InjectVariable(t, var_name); +} + template void PaddleMobile::FeedData(const framework::Tensor &t) { - return executor_->FeedData(t); + executor_->FeedData(t); }; template diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 66d70b87f2795da867b42c8cb58a10c9fe5b35cb..b11f7d7c3b8fc051c1d0da17e769225aab0dc968 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -95,6 +95,7 @@ class PaddleMobile { #ifdef PADDLE_MOBILE_FPGA public: + void InjectVariable(const framework::Tensor &t, string var_name); void FeedData(const framework::Tensor &t); std::shared_ptr FetchResult(); void Predict_From_To(int start = 0, int end = -1);