From 041a31a249f865c20a30dce3072572c6fc7d10e5 Mon Sep 17 00:00:00 2001 From: zhangyang Date: Sat, 15 Sep 2018 15:55:43 +0800 Subject: [PATCH] add InjectVariable function for FPGA track --- src/io/executor.cpp | 11 +++++++++-- src/io/executor.h | 3 +++ src/io/paddle_mobile.cpp | 9 ++++++++- src/io/paddle_mobile.h | 1 + 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 7543a9f7a4..8bc3c15c6d 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -645,15 +645,22 @@ std::vector::Ptype> Executor::Predict( } #ifdef PADDLE_MOBILE_FPGA + template -void Executor::FeedData(const framework::Tensor &t) { - framework::Variable *g_feed_value = program_.scope->Var("feed"); +void Executor::InjectVariable(const framework::Tensor &t, + string var_name) { + framework::Variable *g_feed_value = program_.scope->Var(var_name); framework::Tensor *feed_tensor = g_feed_value->GetMutable(); feed_tensor->Resize(t.dims()); feed_tensor->ShareDataWith(t); }; +template +void Executor::FeedData(const framework::Tensor &t) { + InjectVariable(t, "feed"); +}; + template std::shared_ptr Executor::FetchResult() { std::shared_ptr to_predict_block = diff --git a/src/io/executor.h b/src/io/executor.h index f1f3f9da7c..bec9f45444 100644 --- a/src/io/executor.h +++ b/src/io/executor.h @@ -30,6 +30,7 @@ limitations under the License. */ #include #include "common/dep_core.h" #endif +using std::string; namespace paddle_mobile { @@ -94,7 +95,9 @@ class Executor { framework::LoDTensor *tensor) const; #ifdef PADDLE_MOBILE_FPGA + public: + void InjectVariable(const framework::Tensor &t, string var_name); void FeedData(const framework::Tensor &t); std::shared_ptr FetchResult(); void Predict_From_To(int start = 0, int end = -1); diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 9056bac206..9710a0ec45 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -125,9 +125,16 @@ PaddleMobile::~PaddleMobile() { } #ifdef PADDLE_MOBILE_FPGA + +template +void PaddleMobile::InjectVariable(const framework::Tensor &t, + string var_name) { + executor_->InjectVariable(t, var_name); +} + template void PaddleMobile::FeedData(const framework::Tensor &t) { - return executor_->FeedData(t); + executor_->FeedData(t); }; template diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 66d70b87f2..b11f7d7c3b 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -95,6 +95,7 @@ class PaddleMobile { #ifdef PADDLE_MOBILE_FPGA public: + void InjectVariable(const framework::Tensor &t, string var_name); void FeedData(const framework::Tensor &t); std::shared_ptr FetchResult(); void Predict_From_To(int start = 0, int end = -1); -- GitLab