提交 ad541652 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Ban feed and fetch op during inference (#2198)

* init: delete feed and fetch op, using zero copy
test=develop

* delete the unused test
test=develop
上级 eb7c5829
...@@ -42,13 +42,13 @@ void Predictor::SaveModel(const std::string &dir, ...@@ -42,13 +42,13 @@ void Predictor::SaveModel(const std::string &dir,
} }
lite::Tensor *Predictor::GetInput(size_t offset) { lite::Tensor *Predictor::GetInput(size_t offset) {
auto *_feed_list = exec_scope_->FindVar("feed"); CHECK(input_names_.size() > offset)
CHECK(_feed_list) << "no feed variable in exec_scope"; << "The network has " << input_names_.size() << " inputs"
auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>(); << ", the offset should be less than this.";
if (offset >= feed_list->size()) { auto *in_var = exec_scope_->FindVar(input_names_[offset]);
feed_list->resize(offset + 1); CHECK(in_var) << "no fatch variable " << input_names_[offset]
} << " in exec_scope";
return &feed_list->at(offset); return in_var->GetMutable<lite::Tensor>();
} }
// get inputs names // get inputs names
...@@ -84,18 +84,23 @@ void Predictor::PrepareFeedFetch() { ...@@ -84,18 +84,23 @@ void Predictor::PrepareFeedFetch() {
} }
const lite::Tensor *Predictor::GetOutput(size_t offset) const { const lite::Tensor *Predictor::GetOutput(size_t offset) const {
auto *_fetch_list = exec_scope_->FindVar("fetch"); CHECK(output_names_.size() > offset)
CHECK(_fetch_list) << "no fatch variable in exec_scope"; << "The network has " << output_names_.size() << " outputs"
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); << ", the offset should be less than this.";
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; const std::string name = output_names_.at(offset);
return &fetch_list.at(offset); auto *out_var = exec_scope_->FindVar(name);
CHECK(out_var) << "no fatch variable " << name << " in exec_scope";
return out_var->GetMutable<lite::Tensor>();
} }
const std::vector<lite::Tensor> *Predictor::GetOutputs() const { std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
auto *_fetch_list = exec_scope_->FindVar("fetch"); std::vector<const lite::Tensor *> outputs;
CHECK(_fetch_list) << "no fatch variable in exec_scope"; size_t out_size = output_names_.size();
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); for (size_t i = 0; i < out_size; i++) {
return &fetch_list; const std::string name = output_names_.at(i);
outputs.push_back(GetTensor(name));
}
return outputs;
} }
const cpp::ProgramDesc &Predictor::program_desc() const { const cpp::ProgramDesc &Predictor::program_desc() const {
...@@ -169,6 +174,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -169,6 +174,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
factor.ConsiderDataLayout(); factor.ConsiderDataLayout();
optimizer_.Run(std::move(program), inner_places, factor, passes); optimizer_.Run(std::move(program), inner_places, factor, passes);
exec_scope_ = optimizer_.exec_scope(); exec_scope_ = optimizer_.exec_scope();
PrepareFeedFetch();
} }
void Predictor::GenRuntimeProgram() { void Predictor::GenRuntimeProgram() {
......
...@@ -80,7 +80,7 @@ class LITE_API Predictor { ...@@ -80,7 +80,7 @@ class LITE_API Predictor {
// Get offset-th col of fetch results. // Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const; const lite::Tensor* GetOutput(size_t offset) const;
const std::vector<lite::Tensor>* GetOutputs() const; std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const; const cpp::ProgramDesc& program_desc() const;
const lite::Tensor* GetTensor(const std::string& name) const; const lite::Tensor* GetTensor(const std::string& name) const;
......
...@@ -63,7 +63,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -63,7 +63,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif #endif
auto places = config.valid_places(); auto places = config.valid_places();
raw_predictor_.Build(config, places); raw_predictor_.Build(config, places);
raw_predictor_.PrepareFeedFetch();
} }
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) { std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
...@@ -41,16 +41,17 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -41,16 +41,17 @@ void LightPredictor::Build(const std::string& model_dir,
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
} }
BuildRuntimeProgram(cpp_program_desc_); BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch();
} }
Tensor* LightPredictor::GetInput(size_t offset) { Tensor* LightPredictor::GetInput(size_t offset) {
auto* _feed_list = program_->exec_scope()->FindVar("feed"); CHECK(input_names_.size() > offset)
CHECK(_feed_list) << "no feed variable in exec_scope"; << "The network has " << input_names_.size() << " inputs"
auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>(); << ", the offset should be less than this.";
if (offset >= feed_list->size()) { auto* in_var = program_->exec_scope()->FindVar(input_names_[offset]);
feed_list->resize(offset + 1); CHECK(in_var) << "no fatch variable " << input_names_[offset]
} << " in exec_scope";
return &feed_list->at(offset); return in_var->GetMutable<lite::Tensor>();
} }
// get input by name // get input by name
...@@ -69,11 +70,13 @@ Tensor* LightPredictor::GetInputByName(const std::string& name) { ...@@ -69,11 +70,13 @@ Tensor* LightPredictor::GetInputByName(const std::string& name) {
} }
const Tensor* LightPredictor::GetOutput(size_t offset) { const Tensor* LightPredictor::GetOutput(size_t offset) {
auto* _fetch_list = program_->exec_scope()->FindVar("fetch"); CHECK(output_names_.size() > offset)
CHECK(_fetch_list) << "no fatch variable in exec_scope"; << "The network has " << output_names_.size() << " outputs"
auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>(); << ", the offset should be less than this.";
CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; auto* out_var = program_->exec_scope()->FindVar(output_names_.at(offset));
return &fetch_list.at(offset); CHECK(out_var) << "no fatch variable " << output_names_.at(offset)
<< " in exec_scope";
return out_var->GetMutable<lite::Tensor>();
} }
// get inputs names // get inputs names
std::vector<std::string> LightPredictor::GetInputNames() { std::vector<std::string> LightPredictor::GetInputNames() {
......
...@@ -53,7 +53,6 @@ void LightPredictorImpl::Init(const MobileConfig& config) { ...@@ -53,7 +53,6 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
config.param_buffer(), config.param_buffer(),
config.model_from_memory(), config.model_from_memory(),
LiteModelType::kNaiveBuffer)); LiteModelType::kNaiveBuffer));
raw_predictor_->PrepareFeedFetch();
} }
std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) { std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
......
...@@ -36,7 +36,6 @@ TEST(LightAPI, load) { ...@@ -36,7 +36,6 @@ TEST(LightAPI, load) {
data[i] = i; data[i] = i;
} }
predictor.PrepareFeedFetch();
std::vector<std::string> inputs = predictor.GetInputNames(); std::vector<std::string> inputs = predictor.GetInputNames();
LOG(INFO) << "input size: " << inputs.size(); LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
......
...@@ -58,11 +58,11 @@ TEST(model, test) { ...@@ -58,11 +58,11 @@ TEST(model, test) {
for (int i = 0; i < FLAGS_repeats; ++i) { for (int i = 0; i < FLAGS_repeats; ++i) {
predictor.Run(); predictor.Run();
} }
auto* output_tensors = predictor.GetOutputs(); auto output_tensors = predictor.GetOutputs();
LOG(INFO) << "======output:========"; LOG(INFO) << "======output:========";
for (auto t : *output_tensors) { for (auto* t : output_tensors) {
LOG(INFO) << t; LOG(INFO) << *t;
} }
LOG(INFO) LOG(INFO)
<< "=====RUN_finished!!============= Speed Report ==================="; << "=====RUN_finished!!============= Speed Report ===================";
......
...@@ -43,16 +43,16 @@ const int8_t *Tensor::data() const { ...@@ -43,16 +43,16 @@ const int8_t *Tensor::data() const {
} }
template <> template <>
int *Tensor::mutable_data() const { int *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int>(); return tensor(raw_tensor_)->mutable_data<int>(type);
} }
template <> template <>
float *Tensor::mutable_data() const { float *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<float>(); return tensor(raw_tensor_)->mutable_data<float>(type);
} }
template <> template <>
int8_t *Tensor::mutable_data() const { int8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int8_t>(); return tensor(raw_tensor_)->mutable_data<int8_t>(type);
} }
shape_t Tensor::shape() const { shape_t Tensor::shape() const {
......
...@@ -43,7 +43,7 @@ struct LITE_API Tensor { ...@@ -43,7 +43,7 @@ struct LITE_API Tensor {
const T* data() const; const T* data() const;
template <typename T> template <typename T>
T* mutable_data() const; T* mutable_data(TargetType type = TargetType::kHost) const;
/// Shape of the tensor. /// Shape of the tensor.
shape_t shape() const; shape_t shape() const;
......
...@@ -37,12 +37,6 @@ void scale(int num, const T* in, T* out, float scale, cudaStream_t stream); ...@@ -37,12 +37,6 @@ void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
template <typename T> template <typename T>
void scale(int num, const T* in, T* out, float scale); void scale(int num, const T* in, T* out, float scale);
template <typename T>
void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, float scale);
} // namespace math } // namespace math
} // namespace cuda } // namespace cuda
} // namespace lite } // namespace lite
......
...@@ -113,6 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -113,6 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
void RuntimeProgram::Run() { void RuntimeProgram::Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
std::string op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run(); inst.Run();
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE #ifdef LITE_WITH_PRECISION_PROFILE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册