From 75e8a6fc4c352ed0a28ed4130b066ef08a182eb8 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 16 Oct 2019 22:21:11 +0800 Subject: [PATCH] Ban feed and fetch op during inference (#2198) * init: delete feed and fetch op, using zero copy test=develop * delete the unused test test=develop --- lite/api/cxx_api.cc | 40 ++++++++++++++++++-------------- lite/api/cxx_api.h | 2 +- lite/api/cxx_api_impl.cc | 1 - lite/api/light_api.cc | 27 +++++++++++---------- lite/api/light_api_impl.cc | 1 - lite/api/light_api_test.cc | 1 - lite/api/model_run_test_image.cc | 6 ++--- lite/api/paddle_api.cc | 12 +++++----- lite/api/paddle_api.h | 2 +- lite/backends/cuda/math/scale.h | 6 ----- lite/core/program.cc | 2 ++ 11 files changed, 51 insertions(+), 49 deletions(-) diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index 490a184c2d..502b28d7b4 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -42,13 +42,13 @@ void Predictor::SaveModel(const std::string &dir, } lite::Tensor *Predictor::GetInput(size_t offset) { - auto *_feed_list = exec_scope_->FindVar("feed"); - CHECK(_feed_list) << "no feed variable in exec_scope"; - auto *feed_list = _feed_list->GetMutable>(); - if (offset >= feed_list->size()) { - feed_list->resize(offset + 1); - } - return &feed_list->at(offset); + CHECK(input_names_.size() > offset) + << "The network has " << input_names_.size() << " inputs" + << ", the offset should be less than this."; + auto *in_var = exec_scope_->FindVar(input_names_[offset]); + CHECK(in_var) << "no fatch variable " << input_names_[offset] + << " in exec_scope"; + return in_var->GetMutable(); } // get inputs names @@ -84,18 +84,23 @@ void Predictor::PrepareFeedFetch() { } const lite::Tensor *Predictor::GetOutput(size_t offset) const { - auto *_fetch_list = exec_scope_->FindVar("fetch"); - CHECK(_fetch_list) << "no fatch variable in exec_scope"; - auto &fetch_list = *_fetch_list->GetMutable>(); - CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; - return &fetch_list.at(offset); + CHECK(output_names_.size() > offset) + << "The network has " << output_names_.size() << " outputs" + << ", the offset should be less than this."; + const std::string name = output_names_.at(offset); + auto *out_var = exec_scope_->FindVar(name); + CHECK(out_var) << "no fatch variable " << name << " in exec_scope"; + return out_var->GetMutable(); } -const std::vector *Predictor::GetOutputs() const { - auto *_fetch_list = exec_scope_->FindVar("fetch"); - CHECK(_fetch_list) << "no fatch variable in exec_scope"; - auto &fetch_list = *_fetch_list->GetMutable>(); - return &fetch_list; +std::vector Predictor::GetOutputs() const { + std::vector outputs; + size_t out_size = output_names_.size(); + for (size_t i = 0; i < out_size; i++) { + const std::string name = output_names_.at(i); + outputs.push_back(GetTensor(name)); + } + return outputs; } const cpp::ProgramDesc &Predictor::program_desc() const { @@ -169,6 +174,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc, factor.ConsiderDataLayout(); optimizer_.Run(std::move(program), inner_places, factor, passes); exec_scope_ = optimizer_.exec_scope(); + PrepareFeedFetch(); } void Predictor::GenRuntimeProgram() { diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 7f5490fa9f..3d8dc2f06a 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -80,7 +80,7 @@ class LITE_API Predictor { // Get offset-th col of fetch results. const lite::Tensor* GetOutput(size_t offset) const; - const std::vector* GetOutputs() const; + std::vector GetOutputs() const; const cpp::ProgramDesc& program_desc() const; const lite::Tensor* GetTensor(const std::string& name) const; diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index a92ef0be88..b4fb3828f3 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -63,7 +63,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { #endif auto places = config.valid_places(); raw_predictor_.Build(config, places); - raw_predictor_.PrepareFeedFetch(); } std::unique_ptr CxxPaddleApiImpl::GetInput(int i) { diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index f23a973c83..12963285e4 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -41,16 +41,17 @@ void LightPredictor::Build(const std::string& model_dir, LOG(FATAL) << "Unknown model type"; } BuildRuntimeProgram(cpp_program_desc_); + PrepareFeedFetch(); } Tensor* LightPredictor::GetInput(size_t offset) { - auto* _feed_list = program_->exec_scope()->FindVar("feed"); - CHECK(_feed_list) << "no feed variable in exec_scope"; - auto* feed_list = _feed_list->GetMutable>(); - if (offset >= feed_list->size()) { - feed_list->resize(offset + 1); - } - return &feed_list->at(offset); + CHECK(input_names_.size() > offset) + << "The network has " << input_names_.size() << " inputs" + << ", the offset should be less than this."; + auto* in_var = program_->exec_scope()->FindVar(input_names_[offset]); + CHECK(in_var) << "no fatch variable " << input_names_[offset] + << " in exec_scope"; + return in_var->GetMutable(); } // get input by name @@ -69,11 +70,13 @@ Tensor* LightPredictor::GetInputByName(const std::string& name) { } const Tensor* LightPredictor::GetOutput(size_t offset) { - auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); - CHECK(_fetch_list) << "no fatch variable in exec_scope"; - auto& fetch_list = *_fetch_list->GetMutable>(); - CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; - return &fetch_list.at(offset); + CHECK(output_names_.size() > offset) + << "The network has " << output_names_.size() << " outputs" + << ", the offset should be less than this."; + auto* out_var = program_->exec_scope()->FindVar(output_names_.at(offset)); + CHECK(out_var) << "no fatch variable " << output_names_.at(offset) + << " in exec_scope"; + return out_var->GetMutable(); } // get inputs names std::vector LightPredictor::GetInputNames() { diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc index c857e377fc..90e1397d83 100644 --- a/lite/api/light_api_impl.cc +++ b/lite/api/light_api_impl.cc @@ -53,7 +53,6 @@ void LightPredictorImpl::Init(const MobileConfig& config) { config.param_buffer(), config.model_from_memory(), LiteModelType::kNaiveBuffer)); - raw_predictor_->PrepareFeedFetch(); } std::unique_ptr LightPredictorImpl::GetInput(int i) { diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc index 0960b6c079..418d97e9e8 100644 --- a/lite/api/light_api_test.cc +++ b/lite/api/light_api_test.cc @@ -36,7 +36,6 @@ TEST(LightAPI, load) { data[i] = i; } - predictor.PrepareFeedFetch(); std::vector inputs = predictor.GetInputNames(); LOG(INFO) << "input size: " << inputs.size(); for (int i = 0; i < inputs.size(); i++) { diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc index f3cd35c524..7287613a61 100644 --- a/lite/api/model_run_test_image.cc +++ b/lite/api/model_run_test_image.cc @@ -58,11 +58,11 @@ TEST(model, test) { for (int i = 0; i < FLAGS_repeats; ++i) { predictor.Run(); } - auto* output_tensors = predictor.GetOutputs(); + auto output_tensors = predictor.GetOutputs(); LOG(INFO) << "======output:========"; - for (auto t : *output_tensors) { - LOG(INFO) << t; + for (auto* t : output_tensors) { + LOG(INFO) << *t; } LOG(INFO) << "=====RUN_finished!!============= Speed Report ==================="; diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index 954d687a48..16ae5db777 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -43,16 +43,16 @@ const int8_t *Tensor::data() const { } template <> -int *Tensor::mutable_data() const { - return tensor(raw_tensor_)->mutable_data(); +int *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data(type); } template <> -float *Tensor::mutable_data() const { - return tensor(raw_tensor_)->mutable_data(); +float *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data(type); } template <> -int8_t *Tensor::mutable_data() const { - return tensor(raw_tensor_)->mutable_data(); +int8_t *Tensor::mutable_data(TargetType type) const { + return tensor(raw_tensor_)->mutable_data(type); } shape_t Tensor::shape() const { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index 17417aa729..545ae03f67 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -43,7 +43,7 @@ struct LITE_API Tensor { const T* data() const; template - T* mutable_data() const; + T* mutable_data(TargetType type = TargetType::kHost) const; /// Shape of the tensor. shape_t shape() const; diff --git a/lite/backends/cuda/math/scale.h b/lite/backends/cuda/math/scale.h index f59d080795..83af600ba8 100644 --- a/lite/backends/cuda/math/scale.h +++ b/lite/backends/cuda/math/scale.h @@ -37,12 +37,6 @@ void scale(int num, const T* in, T* out, float scale, cudaStream_t stream); template void scale(int num, const T* in, T* out, float scale); -template -void scale(int num, const T* in, T* out, float scale, cudaStream_t stream); - -template -void scale(int num, const T* in, T* out, float scale); - } // namespace math } // namespace cuda } // namespace lite diff --git a/lite/core/program.cc b/lite/core/program.cc index 22a6dbd5dd..f5238f25ed 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -113,6 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { void RuntimeProgram::Run() { for (auto& inst : instructions_) { + std::string op_type = inst.op()->op_info()->Type(); + if (op_type == "feed" || op_type == "fetch") continue; inst.Run(); #ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE -- GitLab