diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index 490a184c2d74277521eb62a35e626a40872d08b3..502b28d7b4c4e27276d9ac8880c9d46ee25191b1 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -42,13 +42,13 @@ void Predictor::SaveModel(const std::string &dir,
 }
 
 lite::Tensor *Predictor::GetInput(size_t offset) {
-  auto *_feed_list = exec_scope_->FindVar("feed");
-  CHECK(_feed_list) << "no feed variable in exec_scope";
-  auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
-  if (offset >= feed_list->size()) {
-    feed_list->resize(offset + 1);
-  }
-  return &feed_list->at(offset);
+  CHECK(input_names_.size() > offset)
+      << "The network has " << input_names_.size() << " inputs"
+      << ", so the offset should be less than that.";
+  auto *in_var = exec_scope_->FindVar(input_names_[offset]);
+  CHECK(in_var) << "no input variable " << input_names_[offset]
+                << " in exec_scope";
+  return in_var->GetMutable<lite::Tensor>();
 }
 
 // get inputs names
@@ -84,18 +84,23 @@ void Predictor::PrepareFeedFetch() {
 }
 
 const lite::Tensor *Predictor::GetOutput(size_t offset) const {
-  auto *_fetch_list = exec_scope_->FindVar("fetch");
-  CHECK(_fetch_list) << "no fatch variable in exec_scope";
-  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
-  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
-  return &fetch_list.at(offset);
+  CHECK(output_names_.size() > offset)
+      << "The network has " << output_names_.size() << " outputs"
+      << ", so the offset should be less than that.";
+  const std::string name = output_names_.at(offset);
+  auto *out_var = exec_scope_->FindVar(name);
+  CHECK(out_var) << "no output variable " << name << " in exec_scope";
+  return out_var->GetMutable<lite::Tensor>();
 }
 
-const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
-  auto *_fetch_list = exec_scope_->FindVar("fetch");
-  CHECK(_fetch_list) << "no fatch variable in exec_scope";
-  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
-  return &fetch_list;
+std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
+  std::vector<const lite::Tensor *> outputs;
+  size_t out_size = output_names_.size();
+  for (size_t i = 0; i < out_size; i++) {
+    const std::string name = output_names_.at(i);
+    outputs.push_back(GetTensor(name));
+  }
+  return outputs;
 }
 
 const cpp::ProgramDesc &Predictor::program_desc() const {
@@ -169,6 +174,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
   factor.ConsiderDataLayout();
   optimizer_.Run(std::move(program), inner_places, factor, passes);
   exec_scope_ = optimizer_.exec_scope();
+  PrepareFeedFetch();
 }
 
 void Predictor::GenRuntimeProgram() {
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 7f5490fa9f82d882fc2353c515c0542a365d5d32..3d8dc2f06aca24e23a77a0b32dc85a0959290758 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -80,7 +80,7 @@ class LITE_API Predictor {
 
   // Get offset-th col of fetch results.
   const lite::Tensor* GetOutput(size_t offset) const;
-  const std::vector<lite::Tensor>* GetOutputs() const;
+  std::vector<const lite::Tensor*> GetOutputs() const;
 
   const cpp::ProgramDesc& program_desc() const;
   const lite::Tensor* GetTensor(const std::string& name) const;
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index a92ef0be88ae53a5479c57f61acc7d2bca14077d..b4fb3828f3b9b38aa3bcefc1df05d6453d55e771 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -63,7 +63,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
 #endif
   auto places = config.valid_places();
   raw_predictor_.Build(config, places);
-  raw_predictor_.PrepareFeedFetch();
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc
index f23a973c830dc62719f5c4e25b2cd2de294882e8..12963285e482b2ea6c6e761f430699507d45c0c5 100644
--- a/lite/api/light_api.cc
+++ b/lite/api/light_api.cc
@@ -41,16 +41,17 @@ void LightPredictor::Build(const std::string& model_dir,
     LOG(FATAL) << "Unknown model type";
   }
   BuildRuntimeProgram(cpp_program_desc_);
+  PrepareFeedFetch();
 }
 
 Tensor* LightPredictor::GetInput(size_t offset) {
-  auto* _feed_list = program_->exec_scope()->FindVar("feed");
-  CHECK(_feed_list) << "no feed variable in exec_scope";
-  auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
-  if (offset >= feed_list->size()) {
-    feed_list->resize(offset + 1);
-  }
-  return &feed_list->at(offset);
+  CHECK(input_names_.size() > offset)
+      << "The network has " << input_names_.size() << " inputs"
+      << ", so the offset should be less than that.";
+  auto* in_var = program_->exec_scope()->FindVar(input_names_[offset]);
+  CHECK(in_var) << "no input variable " << input_names_[offset]
+                << " in exec_scope";
+  return in_var->GetMutable<Tensor>();
 }
 
 // get input by name
@@ -69,11 +70,13 @@ Tensor* LightPredictor::GetInputByName(const std::string& name) {
 }
 
 const Tensor* LightPredictor::GetOutput(size_t offset) {
-  auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
-  CHECK(_fetch_list) << "no fatch variable in exec_scope";
-  auto& fetch_list = *_fetch_list->GetMutable<std::vector<Tensor>>();
-  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
-  return &fetch_list.at(offset);
+  CHECK(output_names_.size() > offset)
+      << "The network has " << output_names_.size() << " outputs"
+      << ", so the offset should be less than that.";
+  auto* out_var = program_->exec_scope()->FindVar(output_names_.at(offset));
+  CHECK(out_var) << "no output variable " << output_names_.at(offset)
+                 << " in exec_scope";
+  return out_var->GetMutable<Tensor>();
 }
 // get inputs names
 std::vector<std::string> LightPredictor::GetInputNames() {
diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc
index c857e377fcd62644be7783c2bc7431a52e2af277..90e1397d8338adb1ba732fc322ae03520bcce27f 100644
--- a/lite/api/light_api_impl.cc
+++ b/lite/api/light_api_impl.cc
@@ -53,7 +53,6 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
       config.param_buffer(),
      config.model_from_memory(),
       LiteModelType::kNaiveBuffer));
-  raw_predictor_->PrepareFeedFetch();
 }
 
 std::unique_ptr<lite_api::Tensor> LightPredictorImpl::GetInput(int i) {
diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc
index 0960b6c079f9a833206d5b41526360e14a512116..418d97e9e8814b5e6e90a76cbdb6e92677c9c726 100644
--- a/lite/api/light_api_test.cc
+++ b/lite/api/light_api_test.cc
@@ -36,7 +36,6 @@ TEST(LightAPI, load) {
     data[i] = i;
   }
 
-  predictor.PrepareFeedFetch();
   std::vector<std::string> inputs = predictor.GetInputNames();
   LOG(INFO) << "input size: " << inputs.size();
   for (int i = 0; i < inputs.size(); i++) {
diff --git a/lite/api/model_run_test_image.cc b/lite/api/model_run_test_image.cc
index f3cd35c524c4cae7f940fa77a7330722230455da..7287613a61d1027cc596f5b306d32178dac67718 100644
--- a/lite/api/model_run_test_image.cc
+++ b/lite/api/model_run_test_image.cc
@@ -58,11 +58,11 @@ TEST(model, test) {
   for (int i = 0; i < FLAGS_repeats; ++i) {
     predictor.Run();
   }
 
-  auto* output_tensors = predictor.GetOutputs();
+  auto output_tensors = predictor.GetOutputs();
   LOG(INFO) << "======output:========";
-  for (auto t : *output_tensors) {
-    LOG(INFO) << t;
+  for (auto* t : output_tensors) {
+    LOG(INFO) << *t;
   }
   LOG(INFO)
       << "=====RUN_finished!!============= Speed Report ===================";
diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
index 954d687a481ae70663b77ea51529054af921dc9d..16ae5db7776aeea285906bfcb1d68ae30b68bf12 100644
--- a/lite/api/paddle_api.cc
+++ b/lite/api/paddle_api.cc
@@ -43,16 +43,16 @@ const int8_t *Tensor::data() const {
 }
 
 template <>
-int *Tensor::mutable_data() const {
-  return tensor(raw_tensor_)->mutable_data<int>();
+int *Tensor::mutable_data(TargetType type) const {
+  return tensor(raw_tensor_)->mutable_data<int>(type);
 }
 template <>
-float *Tensor::mutable_data() const {
-  return tensor(raw_tensor_)->mutable_data<float>();
+float *Tensor::mutable_data(TargetType type) const {
+  return tensor(raw_tensor_)->mutable_data<float>(type);
 }
 template <>
-int8_t *Tensor::mutable_data() const {
-  return tensor(raw_tensor_)->mutable_data<int8_t>();
+int8_t *Tensor::mutable_data(TargetType type) const {
+  return tensor(raw_tensor_)->mutable_data<int8_t>(type);
 }
 
 shape_t Tensor::shape() const {
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index 17417aa72964dc89346a9af3c1c9a47116dd6cca..545ae03f6725de7649b3278835bda973ade2755e 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -43,7 +43,7 @@ struct LITE_API Tensor {
   const T* data() const;
 
   template <typename T>
-  T* mutable_data() const;
+  T* mutable_data(TargetType type = TargetType::kHost) const;
 
   /// Shape of the tensor.
   shape_t shape() const;
diff --git a/lite/backends/cuda/math/scale.h b/lite/backends/cuda/math/scale.h
index f59d080795f0da9741c75e26a785c1b1b56e2f9b..83af600ba8c68a236fdb2a5c9f8521199b46f633 100644
--- a/lite/backends/cuda/math/scale.h
+++ b/lite/backends/cuda/math/scale.h
@@ -37,12 +37,6 @@ void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
 template <typename T>
 void scale(int num, const T* in, T* out, float scale);
 
-template <typename T>
-void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
-
-template <typename T>
-void scale(int num, const T* in, T* out, float scale);
-
 }  // namespace math
 }  // namespace cuda
 }  // namespace lite
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 22a6dbd5dd0c3b3526f1b8fd6ca30d5db0cbb82d..f5238f25ed16f33160123e39c50ba9689c8b6493 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -113,6 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
 
 void RuntimeProgram::Run() {
   for (auto& inst : instructions_) {
+    std::string op_type = inst.op()->op_info()->Type();
+    if (op_type == "feed" || op_type == "fetch") continue;
    inst.Run();
 #ifdef LITE_WITH_PROFILE
 #ifdef LITE_WITH_PRECISION_PROFILE
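
Reviewer note: the two user-visible API changes above are (1) Predictor::GetOutputs() now returns a std::vector<const lite::Tensor*> by value instead of a pointer to the scope's "fetch" list, and (2) lite_api::Tensor::mutable_data<T>() takes an optional TargetType defaulting to kHost, so existing call sites keep compiling. Below is a minimal sketch of adapted call sites; it is illustrative only, not part of the patch, and DumpOutputs/FillInput are hypothetical helper names.

    // Sketch: mirrors the updated model_run_test_image.cc call sites above.
    #include <vector>
    #include "lite/api/cxx_api.h"
    #include "lite/api/paddle_api.h"

    void DumpOutputs(const paddle::lite::Predictor& predictor) {
      // GetOutputs() copies out one const Tensor* per output name.
      std::vector<const paddle::lite::Tensor*> outputs = predictor.GetOutputs();
      for (const auto* t : outputs) {
        LOG(INFO) << *t;  // Tensor streams its contents, as in the test above.
      }
    }

    void FillInput(paddle::lite_api::Tensor* input) {
      // mutable_data<T>() now picks where the buffer lives; the default
      // argument (TargetType::kHost) preserves the old host-memory behavior.
      float* host_data = input->mutable_data<float>();
      host_data[0] = 1.0f;
    }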