提交 9ffd3c9f 编写于 作者: Z zhangyang

add InjectVariable function for FPGA track

上级 758687de
......@@ -645,15 +645,22 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
}
#ifdef PADDLE_MOBILE_FPGA
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
framework::Variable *g_feed_value = program_.scope->Var("feed");
void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
string var_name) {
framework::Variable *g_feed_value = program_.scope->Var(var_name);
framework::Tensor *feed_tensor =
g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
};
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult() {
std::shared_ptr<framework::BlockDesc> to_predict_block =
......
......@@ -30,6 +30,7 @@ limitations under the License. */
#include <thread>
#include "common/dep_core.h"
#endif
using std::string;
namespace paddle_mobile {
......@@ -94,7 +95,9 @@ class Executor {
framework::LoDTensor *tensor) const;
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult();
void Predict_From_To(int start = 0, int end = -1);
......
......@@ -125,9 +125,16 @@ PaddleMobile<Dtype, P>::~PaddleMobile() {
}
#ifdef PADDLE_MOBILE_FPGA
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::InjectVariable(const framework::Tensor &t,
string var_name) {
executor_->InjectVariable(t, var_name);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::FeedData(const framework::Tensor &t) {
return executor_->FeedData(t);
executor_->FeedData(t);
};
template <typename Dtype, Precision P>
......
......@@ -95,6 +95,7 @@ class PaddleMobile {
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult();
void Predict_From_To(int start = 0, int end = -1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册