diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc
index 32a691b81ffc0586a07f4f06d2114fa5da2e18e2..0cf7f27b7cbe7e009246dc4b939da9b9c94f7bc9 100644
--- a/paddle/fluid/inference/api/api_impl.cc
+++ b/paddle/fluid/inference/api/api_impl.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -57,6 +58,27 @@ std::string num2str(T a) {
 }
 }  // namespace
 
+void NativePaddlePredictor::PrepareFeedFetch() {
+  for (auto *op : inference_program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      VLOG(3) << "feed " << idx << " " << op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      int idx = boost::get<int>(op->GetAttr("col"));
+      if (fetchs_.size() <= static_cast<size_t>(idx)) {
+        fetchs_.resize(idx + 1);
+      }
+      fetchs_[idx] = op;
+      VLOG(3) << "fetch " << idx << " " << op->Input("X")[0];
+    }
+  }
+}
+
 bool NativePaddlePredictor::Init(
     std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
@@ -108,8 +130,7 @@ bool NativePaddlePredictor::Init(
                              sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
   // Get the feed_target_names and fetch_target_names
-  feed_target_names_ = inference_program_->GetFeedTargetNames();
-  fetch_target_names_ = inference_program_->GetFetchTargetNames();
+  PrepareFeedFetch();
   return true;
 }
@@ -130,36 +151,20 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, const framework::LoDTensor *> feed_targets;
-  std::vector<framework::LoDTensor> feeds;
-  if (!SetFeed(inputs, &feeds)) {
+  framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get();
+  if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    if (config_.specify_input_name) {
-      feed_targets[inputs[i].name] = &feeds[i];
-    } else {
-      feed_targets[feed_target_names_[i]] = &feeds[i];
-    }
-  }
-  // get fetch variable
-  std::map<std::string, framework::LoDTensor *> fetch_targets;
-  std::vector<framework::LoDTensor> fetchs;
-  fetchs.resize(fetch_target_names_.size());
-  for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
-    fetch_targets[fetch_target_names_[i]] = &fetchs[i];
-  }
   // Run the inference program
   // if share variables, we need not create variables
   VLOG(4) << "Run prepared context";
-  executor_->RunPreparedContext(
-      ctx_.get(), sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
-      &feed_targets, &fetch_targets,
-      false, /* don't create local scope each time*/
-      false /* don't create variable eatch time */);
+  executor_->RunPreparedContext(ctx_.get(), scope,
+                                false, /* don't create local scope each time*/
+                                false /* don't create variable each time */);
   VLOG(4) << "Finish prepared context";
-  if (!GetFetch(fetchs, output_data)) {
+  // get fetch variable
+  if (!GetFetch(output_data, scope)) {
     LOG(ERROR) << "fail to get fetches";
     return false;
   }
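The hunks above replace the per-call feed/fetch target maps with one-time slot indexing: PrepareFeedFetch walks block 0 once and records every feed/fetch op under its "col" attribute, so Run only has to address slots. A minimal standalone sketch of that indexing logic (plain C++; the Op struct is a made-up stand-in for framework::OpDesc, not Paddle's API):

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    // Made-up stand-in for framework::OpDesc: an op type, a column index
    // (the "col" attribute), and its single input/output variable name.
    struct Op {
      std::string type;
      int col;
      std::string var;
    };

    int main() {
      // Feed/fetch ops may appear in any order inside the block.
      std::vector<Op> ops = {{"fetch", 1, "prob"},
                             {"feed", 0, "image"},
                             {"feed", 1, "label"},
                             {"fetch", 0, "loss"}};

      std::vector<const Op *> feeds, fetchs;
      std::map<std::string, size_t> feed_names;

      for (const auto &op : ops) {
        if (op.type == "feed") {
          if (feeds.size() <= static_cast<size_t>(op.col)) {
            feeds.resize(op.col + 1);
          }
          feeds[op.col] = &op;          // index by the op's "col" attribute
          feed_names[op.var] = op.col;  // and allow lookup by variable name
        } else if (op.type == "fetch") {
          if (fetchs.size() <= static_cast<size_t>(op.col)) {
            fetchs.resize(op.col + 1);
          }
          fetchs[op.col] = &op;
        }
      }

      assert(feed_names["label"] == 1);
      assert(fetchs[0]->var == "loss");
      return 0;
    }

The feed_names map is what makes config_.specify_input_name work: callers can label inputs by variable name instead of relying on positional order.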
@@ -180,13 +186,13 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
 }
 
 bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
-                                    std::vector<framework::LoDTensor> *feeds) {
+                                    framework::Scope *scope) {
   VLOG(3) << "Predictor::set_feed";
-  if (inputs.size() != feed_target_names_.size()) {
+  if (inputs.size() != feeds_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
-  for (size_t i = 0; i < feed_target_names_.size(); ++i) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
     framework::LoDTensor input;
     framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
@@ -208,29 +214,40 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
       lod.emplace_back(level);
     }
     input.set_lod(lod);
-
-    feeds->push_back(input);
+    int idx = -1;
+    if (config_.specify_input_name) {
+      idx =
+          boost::get<int>(feeds_[feed_names_[inputs[i].name]]->GetAttr("col"));
+    } else {
+      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
+    }
+    framework::SetFeedVariable(scope, input, "feed", idx);
   }
   return true;
 }
 
-bool NativePaddlePredictor::GetFetch(
-    const std::vector<framework::LoDTensor> &fetchs,
-    std::vector<PaddleTensor> *outputs) {
+bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
+                                     framework::Scope *scope) {
   VLOG(3) << "Predictor::get_fetch";
-  outputs->resize(fetchs.size());
-  for (size_t i = 0; i < fetchs.size(); ++i) {
+  outputs->resize(fetchs_.size());
+  for (size_t i = 0; i < fetchs_.size(); ++i) {
+    std::string fetch_target_name = fetchs_[i]->Input("X")[0];
+    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+    PADDLE_ENFORCE(idx == i);
+    framework::LoDTensor &output =
+        framework::GetFetchVariable(*scope, "fetch", idx);
     // TODO(panyx0718): Support fetch of other types.
-    if (fetchs[i].type() != typeid(float)) {
+    if (output.type() != typeid(float)) {
       LOG(ERROR) << "only support fetching float now.";
      return false;
     }
+
     std::vector<int> shape;
-    auto dims_i = fetchs[i].dims();
-    auto lod = fetchs[i].lod();
-    const float *output_ptr = fetchs[i].data<float>();
+    auto dims_i = output.dims();
+    auto lod = output.lod();
+    const float *output_ptr = output.data<float>();
     // const int64_t* output_ptr = fetchs[i].data<int64_t>();
-    auto num = fetchs[i].numel();
+    auto num = output.numel();
     std::vector<float> data;
     if (0 == lod.size()) {
       std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
@@ -275,7 +292,7 @@ bool NativePaddlePredictor::GetFetch(
     }
     std::memcpy(buffer.data(), data.data(), buffer.length());
     // copy LoD
-    for (const auto &level : fetchs[i].lod()) {
+    for (const auto &level : output.lod()) {
       outputs->at(i).lod.emplace_back(level);
     }
     outputs->at(i).dtype = PaddleDType::FLOAT32;
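SetFeed and GetFetch now move data through scope-owned feed/fetch holders rather than passing tensor vectors around. As a mental model only (a toy sketch in which a float stands in for a whole LoDTensor and a bare map stands in for framework::Scope; the real helpers live in paddle/fluid/framework/feed_fetch_method.h, where the holder type FeedFetchList is a vector of LoDTensor):

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    using Tensor = float;                       // stand-in for LoDTensor
    using FeedFetchList = std::vector<Tensor>;  // holder: one slot per col

    struct Scope {
      std::map<std::string, FeedFetchList> vars;
    };

    // Analogous to framework::SetFeedVariable(scope, input, "feed", idx).
    void SetFeedVariable(Scope *scope, const Tensor &input,
                         const std::string &name, size_t idx) {
      auto &list = scope->vars[name];
      if (list.size() <= idx) list.resize(idx + 1);
      list[idx] = input;  // the program's feed op later reads slot idx
    }

    // Analogous to framework::GetFetchVariable(*scope, "fetch", idx).
    const Tensor &GetFetchVariable(const Scope &scope, const std::string &name,
                                   size_t idx) {
      return scope.vars.at(name).at(idx);
    }

    int main() {
      Scope scope;
      SetFeedVariable(&scope, 3.14f, "feed", 0);
      // In the real flow the executor's fetch ops fill the "fetch" holder;
      // here we fake that step before reading the result back.
      scope.vars["fetch"] = {42.0f};
      assert(GetFetchVariable(scope, "fetch", 0) == 42.0f);
      return 0;
    }

In the real code the feed op copies slot idx of the "feed" holder into the program's input variable, and the fetch op copies the result into slot idx of the "fetch" holder, which GetFetch then converts back into a PaddleTensor.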
diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h
index 4f28c3cd34bade4189871210e6168c6c1c610c2c..4eff9204eba987aed11e4066fa7b6f6cc610a763 100644
--- a/paddle/fluid/inference/api/api_impl.h
+++ b/paddle/fluid/inference/api/api_impl.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <glog/logging.h>
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -47,9 +48,11 @@ class NativePaddlePredictor : public PaddlePredictor {
 
  protected:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
-               std::vector<framework::LoDTensor> *feeds);
-  bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
-                std::vector<PaddleTensor> *output_data);
+               framework::Scope *scope);
+  bool GetFetch(std::vector<PaddleTensor> *output_data,
+                framework::Scope *scope);
+
+  void PrepareFeedFetch();
 
   NativeConfig config_;
   platform::Place place_;
@@ -57,8 +60,9 @@ class NativePaddlePredictor : public PaddlePredictor {
   std::shared_ptr<framework::Scope> scope_;
   std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
   std::unique_ptr<framework::ProgramDesc> inference_program_;
-  std::vector<std::string> feed_target_names_;
-  std::vector<std::string> fetch_target_names_;
+  std::vector<framework::OpDesc *> feeds_;
+  std::map<std::string, size_t> feed_names_;
+  std::vector<framework::OpDesc *> fetchs_;
   // Do not use unique_ptr, use parent scope to delete
   framework::Scope *sub_scope_{nullptr};
 };
diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
index 93de7a5209e7dc289b4b02e73ef3bb20bfc8c774..abee375313850f1490bacec11f737706c061a5e9 100644
--- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
@@ -74,10 +74,8 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
     VLOG(5) << "to create variables";
     executor_->CreateVariables(*inference_program_,
                                sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
-    // Get the feed_target_names and fetch_target_names
-    feed_target_names_ = inference_program_->GetFeedTargetNames();
-    fetch_target_names_ = inference_program_->GetFetchTargetNames();
+    PrepareFeedFetch();
     return true;
   }
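None of this changes the public surface: SetFeed, GetFetch, and the new members are protected (which is also what lets TensorRTSubgraphPredictor reuse PrepareFeedFetch), and callers still go through PaddlePredictor::Run. A hedged usage sketch of that unchanged entry point (the model path, input name, and shape below are invented for illustration):

    #include <vector>
    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    int main() {
      paddle::NativeConfig config;
      config.model_dir = "/path/to/model";  // hypothetical model directory
      config.use_gpu = false;

      auto predictor =
          paddle::CreatePaddlePredictor<paddle::NativeConfig>(config);

      // One float input; the name is only consulted when
      // config.specify_input_name is set.
      std::vector<float> buf(4, 1.0f);
      paddle::PaddleTensor input;
      input.name = "image";
      input.shape = {1, 4};
      input.dtype = paddle::PaddleDType::FLOAT32;
      input.data = paddle::PaddleBuf(buf.data(), buf.size() * sizeof(float));

      std::vector<paddle::PaddleTensor> outputs;
      predictor->Run({input}, &outputs);
      return 0;
    }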
diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
index e2a3e9d46ef9f303d191d59253ffbe9f4826184b..cbcfc964c91c33ab41a72ad7fec759086ad887cc 100644
--- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/inference/tests/test_helper.h"
 #include "paddle/fluid/platform/cpu_helper.h"
+#include "paddle/fluid/framework/feed_fetch_method.h"
+
 DEFINE_string(model_path, "", "Directory of the inference model.");
 DEFINE_string(data_file, "", "File of input index data.");
 DEFINE_int32(repeat, 100, "Running the inference program repeat times");
@@ -124,14 +126,34 @@ void ThreadRunInfer(
   std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
   PADDLE_ENFORCE_EQ(feed_target_names.size(), 1UL);
 
   auto& inputs = jobs[tid];
   auto start_ms = GetCurrentMs();
   for (size_t i = 0; i < inputs.size(); ++i) {
     feed_targets[feed_target_names[0]] = inputs[i];
-    executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets,
-                                &fetch_targets, false /*create_local_scope*/);
+    // map the data of feed_targets to feed_holder
+    for (auto* op : inference_program->Block(0).AllOps()) {
+      if (op->Type() == "feed") {
+        std::string feed_target_name = op->Output("Out")[0];
+        int idx = boost::get<int>(op->GetAttr("col"));
+        paddle::framework::SetFeedVariable(
+            scope, *feed_targets[feed_target_name], "feed", idx);
+      }
+    }
+    executor.RunPreparedContext(ctx.get(), &sub_scope,
+                                false /*create_local_scope*/);
   }
   auto stop_ms = GetCurrentMs();
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : inference_program->Block(0).AllOps()) {
+    if (op->Type() == "fetch") {
+      std::string fetch_target_name = op->Input("X")[0];
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          paddle::framework::GetFetchVariable(*scope, "fetch", idx);
+    }
+  }
+
   scope->DeleteScope(&sub_scope);
   LOG(INFO) << "Tid: " << tid << ", process " << inputs.size()
             << " samples, avg time per sample: "
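With the feed/fetch maps gone from RunPreparedContext, each worker thread in the test performs the whole feed, run, fetch cycle itself, so the only cross-thread coordination left is the initial sharding of jobs. A generic sketch of that dispatch pattern (plain C++; RunOneSample is a hypothetical stand-in for the feed/run/fetch body above, and the round-robin split is illustrative rather than the test's exact SplitData policy):

    #include <cstdio>
    #include <thread>
    #include <vector>

    // Hypothetical stand-in for one feed -> RunPreparedContext -> fetch pass.
    void RunOneSample(int sample) { std::printf("sample %d\n", sample); }

    int main() {
      const int num_threads = 4;
      std::vector<int> samples(100);
      for (int i = 0; i < 100; ++i) samples[i] = i;

      // Shard the samples across threads, analogous to jobs[tid] in the test.
      std::vector<std::vector<int>> jobs(num_threads);
      for (size_t i = 0; i < samples.size(); ++i) {
        jobs[i % num_threads].push_back(samples[i]);
      }

      std::vector<std::thread> threads;
      for (int tid = 0; tid < num_threads; ++tid) {
        threads.emplace_back([tid, &jobs] {
          for (int sample : jobs[tid]) RunOneSample(sample);
        });
      }
      for (auto& t : threads) t.join();
      return 0;
    }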