提交 9ffd3c9f 编写于 作者: Z zhangyang

add InjectVariable function for FPGA track

上级 758687de
...@@ -645,15 +645,22 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict( ...@@ -645,15 +645,22 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
} }
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) { void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
framework::Variable *g_feed_value = program_.scope->Var("feed"); string var_name) {
framework::Variable *g_feed_value = program_.scope->Var(var_name);
framework::Tensor *feed_tensor = framework::Tensor *feed_tensor =
g_feed_value->GetMutable<framework::LoDTensor>(); g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims()); feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t); feed_tensor->ShareDataWith(t);
}; };
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult() { std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult() {
std::shared_ptr<framework::BlockDesc> to_predict_block = std::shared_ptr<framework::BlockDesc> to_predict_block =
......
...@@ -30,6 +30,7 @@ limitations under the License. */ ...@@ -30,6 +30,7 @@ limitations under the License. */
#include <thread> #include <thread>
#include "common/dep_core.h" #include "common/dep_core.h"
#endif #endif
using std::string;
namespace paddle_mobile { namespace paddle_mobile {
...@@ -94,7 +95,9 @@ class Executor { ...@@ -94,7 +95,9 @@ class Executor {
framework::LoDTensor *tensor) const; framework::LoDTensor *tensor) const;
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
public: public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t); void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(); std::shared_ptr<framework::Tensor> FetchResult();
void Predict_From_To(int start = 0, int end = -1); void Predict_From_To(int start = 0, int end = -1);
......
...@@ -125,9 +125,16 @@ PaddleMobile<Dtype, P>::~PaddleMobile() { ...@@ -125,9 +125,16 @@ PaddleMobile<Dtype, P>::~PaddleMobile() {
} }
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::InjectVariable(const framework::Tensor &t,
string var_name) {
executor_->InjectVariable(t, var_name);
}
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::FeedData(const framework::Tensor &t) { void PaddleMobile<Dtype, P>::FeedData(const framework::Tensor &t) {
return executor_->FeedData(t); executor_->FeedData(t);
}; };
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
......
...@@ -95,6 +95,7 @@ class PaddleMobile { ...@@ -95,6 +95,7 @@ class PaddleMobile {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
public: public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t); void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(); std::shared_ptr<framework::Tensor> FetchResult();
void Predict_From_To(int start = 0, int end = -1); void Predict_From_To(int start = 0, int end = -1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册