From 458183af7dcbf08fbc572ce354d1af535f3cc4aa Mon Sep 17 00:00:00 2001 From: zhangyang0701 Date: Fri, 1 Feb 2019 21:31:39 +0800 Subject: [PATCH] support multiple feed for FPGA track --- src/framework/executor.cpp | 69 +++++++++++++++++++++++++- src/framework/executor.h | 3 ++ src/framework/operator.cpp | 15 ++++++ src/framework/operator.h | 4 +- src/framework/program/program_desc.cpp | 3 +- src/framework/scope.cpp | 26 ++++++++++ src/framework/scope.h | 7 +++ src/io/paddle_mobile.cpp | 6 ++- src/io/paddle_mobile.h | 1 + 9 files changed, 130 insertions(+), 4 deletions(-) diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index e5b3fadfed..8b54619ae3 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -92,6 +92,14 @@ Executor::Executor(const Program &program, ops_list_.push_back(op_handler); } } +#ifdef PADDLE_MOBILE_FPGA + TalorFeedOp(); + DLOG << "TalorFeed finished"; + TalorFetchdOp(); + DLOG << "TalorFetch finished"; + program_.scope->print_vars(); + +#endif } template @@ -443,11 +451,54 @@ std::shared_ptr Executor::GetOutput( } #ifdef PADDLE_MOBILE_FPGA +template +void Executor::TalorFeedOp() { + auto &ops = ops_of_block_[0]; + int num = 0; + program_.scope->EraseVars(std::vector{string("feed")}); + for (auto op : ops) { + if (op->Type() == "feed") { + auto new_name = string("feed") + std::to_string(num++); + auto var = program_.scope->Var(new_name); + auto tensor = var->template GetMutable(); + auto output_map = op->Outputs(); + std::vector out_keys = op->GetOutKeys(); + PADDLE_MOBILE_ENFORCE(!out_keys.empty(), "this op contains no output"); + auto output_tensor = + GetVarValue(out_keys[0], output_map, *(program_.scope)); + tensor->Resize(output_tensor->dims()); + tensor->init(typeid(float)); + op->ChangeNameMap("X", std::vector{new_name}); + } + } +} +template +void Executor::TalorFetchdOp() { + auto &ops = ops_of_block_[0]; + int num = 0; + program_.scope->EraseVars(std::vector{string("fetch")}); + for (auto op : ops) { + if (op->Type() == "fetch") { + auto new_name = string("fetch") + std::to_string(num++); + auto var = program_.scope->Var(new_name); + auto tensor = var->template GetMutable(); + auto input_map = op->Inputs(); + std::vector in_keys = op->GetInputKeys(); + PADDLE_MOBILE_ENFORCE(!in_keys.empty(), "this op contains no input"); + auto input_tensor = + GetVarValue(in_keys[0], input_map, *(program_.scope)); + tensor->Resize(input_tensor->dims()); + tensor->init(typeid(float)); + op->ChangeNameMap("Out", std::vector{new_name}); + } + } +} + template void Executor::InjectVariable(const Tensor &t, std::string var_name) { Variable *g_feed_value = program_.scope->Var(var_name); - Tensor *feed_tensor = g_feed_value->GetMutable(); + Tensor *feed_tensor = g_feed_value->template GetMutable(); feed_tensor->Resize(t.dims()); feed_tensor->ShareDataWith(t); } @@ -457,6 +508,22 @@ void Executor::FeedData(const Tensor &t) { InjectVariable(t, "feed"); } +template +void Executor::FeedData(const std::vector &v) { + auto input_size = v.size(); + PADDLE_MOBILE_ENFORCE(input_size > 0, "Empty input"); + int counter = 0; + auto vars = program_.scope->VarContain("feed"); + for (auto var : vars) { + Tensor *feed_tensor = var->template GetMutable(); + feed_tensor->Resize(v[counter].dims()); + feed_tensor->ShareDataWith(v[counter]); + if (++counter > v.size()) { + return; + } + } +} + template std::shared_ptr Executor::FetchResult(int id) { auto &ops = ops_of_block_[0]; diff --git a/src/framework/executor.h b/src/framework/executor.h index dc5a542362..2bce5c39b5 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -50,8 +50,11 @@ class Executor { std::shared_ptr GetOutput(const std::string &var_name); #ifdef PADDLE_MOBILE_FPGA + void TalorFeedOp(); + void TalorFetchdOp(); void InjectVariable(const Tensor &t, std::string var_name); void FeedData(const Tensor &t); + void FeedData(const std::vector &v); std::shared_ptr FetchResult(int id = -1); void Predict_From_To(int start = 0, int end = -1); void Predict_From(int start); diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 611b134eaa..0d861e542f 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -131,6 +131,21 @@ void OperatorBase::Run() { } #endif +#ifdef PADDLE_MOBILE_FPGA +template +void OperatorBase::ChangeNameMap(string key, std::vector value) { + auto it = inputs_.find(key); + if (it != inputs_.end()) { + inputs_[key] = value; + return; + } + it = outputs_.find(key); + if (it != outputs_.end()) { + inputs_[key] = value; + } +} +#endif + template class OperatorBase; template class OperatorBase; template class OperatorBase; diff --git a/src/framework/operator.h b/src/framework/operator.h index deb573571f..28bef0eec8 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -78,7 +78,9 @@ class OperatorBase { this->scope_->EraseVars(var_names); } } - +#ifdef PADDLE_MOBILE_FPGA + void ChangeNameMap(string key, std::vector value); +#endif protected: std::shared_ptr scope_; std::string type_; diff --git a/src/framework/program/program_desc.cpp b/src/framework/program/program_desc.cpp index 8483e1e5d6..6c203865a5 100644 --- a/src/framework/program/program_desc.cpp +++ b/src/framework/program/program_desc.cpp @@ -53,7 +53,8 @@ void ProgramDesc::Description(std::string header) { } } for (auto &attr : op->GetAttrMap()) { - LOG(kLOG_DEBUG2) << "attr name:: " << attr.first; + if (attr.first == "op_callstack") continue; + LOG(kLOG_DEBUG2) << "attr name: " << attr.first; LOG(kLOG_DEBUG3) << "argument - " << attr.second; } } diff --git a/src/framework/scope.cpp b/src/framework/scope.cpp index a1f5789aa5..db26308144 100644 --- a/src/framework/scope.cpp +++ b/src/framework/scope.cpp @@ -111,5 +111,31 @@ Variable *Scope::FindVarLocally(const std::string &name) const { return nullptr; } +#ifdef PADDLE_MOBILE_FPGA +Variable *Scope::Var(const std::string &name, const int id) { + return Var(name + std::to_string(id)); +} + +std::vector Scope::VarContain(const std::string substring) { + std::vector v; + for (auto pair : vars_) { + if (pair.first.find(substring) == 0) { + v.push_back(pair.second); + } + } + return v; +} + +void Scope::InsertVar(const std::string str, Variable *var) {} + +void Scope::print_vars() { + DLOG << "====================start to print variables================="; + for (auto pair : vars_) { + DLOG << pair.first; + } + DLOG << "==================complete printing variables================"; +} +#endif + } // namespace framework } // namespace paddle_mobile diff --git a/src/framework/scope.h b/src/framework/scope.h index b853b97e59..d9e3a179e0 100644 --- a/src/framework/scope.h +++ b/src/framework/scope.h @@ -83,6 +83,13 @@ class Scope { Variable *FindVarLocally(const std::string &name) const; +#ifdef PADDLE_MOBILE_FPGA + Variable *Var(const std::string &name, const int id); + std::vector VarContain(const std::string substring); + void InsertVar(const std::string str, Variable *var); + void print_vars(); +#endif + #ifdef PADDLE_MOBILE_CL CLScope *GetCLScpoe() { return cl_scope_; } #endif diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 7e8fcb8288..ea76d8a67c 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -227,7 +227,11 @@ template void PaddleMobile::FeedData(const framework::Tensor &t) { executor_->FeedData(t); } - +template +void PaddleMobile::FeedData( + const std::vector &v) { + executor_->FeedData(v); +}; template std::shared_ptr PaddleMobile::FetchResult( int id) { diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 55fcaf3598..02a1ed1b50 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -90,6 +90,7 @@ class PaddleMobile { #ifdef PADDLE_MOBILE_FPGA void InjectVariable(const framework::Tensor &t, std::string var_name); void FeedData(const framework::Tensor &t); + void FeedData(const std::vector &v); std::shared_ptr FetchResult(int id = -1); void Predict_From_To(int start = 0, int end = -1); void Predict_From(int start); -- GitLab