Unverified · Commit 75e8a6fc · authored by Zhaolong Xing · committed by GitHub

Ban feed and fetch op during inference (#2198)

* init: delete feed and fetch op, using zero copy
test=develop

* delete the unused test
test=develop
Parent: 781d8191
......@@ -42,13 +42,13 @@ void Predictor::SaveModel(const std::string &dir,
}
lite::Tensor *Predictor::GetInput(size_t offset) {
  // Returns the offset-th input tensor, resolved directly by its variable
  // name. Feed ops are banned during inference, so inputs live as named
  // variables in the exec scope instead of entries of a "feed" vector
  // (zero-copy: the caller writes straight into the graph's input tensor).
  // NOTE(review): assumes PrepareFeedFetch() already populated input_names_
  // (Build() calls it) — verify before calling this earlier.
  CHECK(input_names_.size() > offset)
      << "The network has " << input_names_.size() << " inputs"
      << ", the offset should be less than this.";
  auto *in_var = exec_scope_->FindVar(input_names_[offset]);
  // Fixed message: this looks up an input, and "fatch" was a typo.
  CHECK(in_var) << "no input variable " << input_names_[offset]
                << " in exec_scope";
  return in_var->GetMutable<lite::Tensor>();
}
// get inputs names
......@@ -84,18 +84,23 @@ void Predictor::PrepareFeedFetch() {
}
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
  // Returns the offset-th output tensor, resolved by variable name.
  // Fetch ops are banned during inference, so outputs are plain named
  // variables in the exec scope rather than entries of a "fetch" vector.
  CHECK(output_names_.size() > offset)
      << "The network has " << output_names_.size() << " outputs"
      << ", the offset should be less than this.";
  // Bind by reference: no need to copy the name just to look it up.
  const std::string &name = output_names_.at(offset);
  auto *out_var = exec_scope_->FindVar(name);
  // Fixed typo in the error message: "fatch" -> "fetch".
  CHECK(out_var) << "no fetch variable " << name << " in exec_scope";
  return out_var->GetMutable<lite::Tensor>();
}
const std::vector<lite::Tensor> *Predictor::GetOutputs() const {
auto *_fetch_list = exec_scope_->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
return &fetch_list;
std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
std::vector<const lite::Tensor *> outputs;
size_t out_size = output_names_.size();
for (size_t i = 0; i < out_size; i++) {
const std::string name = output_names_.at(i);
outputs.push_back(GetTensor(name));
}
return outputs;
}
const cpp::ProgramDesc &Predictor::program_desc() const {
......@@ -169,6 +174,7 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
factor.ConsiderDataLayout();
optimizer_.Run(std::move(program), inner_places, factor, passes);
exec_scope_ = optimizer_.exec_scope();
PrepareFeedFetch();
}
void Predictor::GenRuntimeProgram() {
......
......@@ -80,7 +80,7 @@ class LITE_API Predictor {
// Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset) const;
const std::vector<lite::Tensor>* GetOutputs() const;
std::vector<const lite::Tensor*> GetOutputs() const;
const cpp::ProgramDesc& program_desc() const;
const lite::Tensor* GetTensor(const std::string& name) const;
......
......@@ -63,7 +63,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#endif
auto places = config.valid_places();
raw_predictor_.Build(config, places);
raw_predictor_.PrepareFeedFetch();
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
......@@ -41,16 +41,17 @@ void LightPredictor::Build(const std::string& model_dir,
LOG(FATAL) << "Unknown model type";
}
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch();
}
Tensor* LightPredictor::GetInput(size_t offset) {
  // Returns the offset-th input tensor by variable name (zero-copy path;
  // feed ops are banned during inference, mirroring Predictor::GetInput).
  // NOTE(review): assumes PrepareFeedFetch() already populated input_names_
  // (Build() calls it) — verify before calling this earlier.
  CHECK(input_names_.size() > offset)
      << "The network has " << input_names_.size() << " inputs"
      << ", the offset should be less than this.";
  auto* in_var = program_->exec_scope()->FindVar(input_names_[offset]);
  // Fixed message: this looks up an input, and "fatch" was a typo.
  CHECK(in_var) << "no input variable " << input_names_[offset]
                << " in exec_scope";
  return in_var->GetMutable<lite::Tensor>();
}
// get input by name
......@@ -69,11 +70,13 @@ Tensor* LightPredictor::GetInputByName(const std::string& name) {
}
const Tensor* LightPredictor::GetOutput(size_t offset) {
  // Returns the offset-th output tensor by variable name (fetch ops are
  // banned during inference, mirroring Predictor::GetOutput).
  CHECK(output_names_.size() > offset)
      << "The network has " << output_names_.size() << " outputs"
      << ", the offset should be less than this.";
  // Look the name up once instead of calling .at(offset) twice.
  const std::string& name = output_names_.at(offset);
  auto* out_var = program_->exec_scope()->FindVar(name);
  // Fixed typo in the error message: "fatch" -> "fetch".
  CHECK(out_var) << "no fetch variable " << name << " in exec_scope";
  return out_var->GetMutable<lite::Tensor>();
}
// get inputs names
std::vector<std::string> LightPredictor::GetInputNames() {
......
......@@ -53,7 +53,6 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
config.param_buffer(),
config.model_from_memory(),
LiteModelType::kNaiveBuffer));
raw_predictor_->PrepareFeedFetch();
}
std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
......
......@@ -36,7 +36,6 @@ TEST(LightAPI, load) {
data[i] = i;
}
predictor.PrepareFeedFetch();
std::vector<std::string> inputs = predictor.GetInputNames();
LOG(INFO) << "input size: " << inputs.size();
for (int i = 0; i < inputs.size(); i++) {
......
......@@ -58,11 +58,11 @@ TEST(model, test) {
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor.Run();
}
auto* output_tensors = predictor.GetOutputs();
auto output_tensors = predictor.GetOutputs();
LOG(INFO) << "======output:========";
for (auto t : *output_tensors) {
LOG(INFO) << t;
for (auto* t : output_tensors) {
LOG(INFO) << *t;
}
LOG(INFO)
<< "=====RUN_finished!!============= Speed Report ===================";
......
......@@ -43,16 +43,16 @@ const int8_t *Tensor::data() const {
}
template <>
int *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<int>();
int *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int>(type);
}
template <>
float *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<float>();
float *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<float>(type);
}
template <>
int8_t *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<int8_t>();
int8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int8_t>(type);
}
shape_t Tensor::shape() const {
......
......@@ -43,7 +43,7 @@ struct LITE_API Tensor {
const T* data() const;
template <typename T>
T* mutable_data() const;
T* mutable_data(TargetType type = TargetType::kHost) const;
/// Shape of the tensor.
shape_t shape() const;
......
......@@ -37,12 +37,6 @@ void scale(int num, const T* in, T* out, float scale, cudaStream_t stream);
// Element-wise scale: out[i] = in[i] * scale for i in [0, num).
// (The cudaStream_t overload declared above runs asynchronously on the
// given stream; this one is the synchronous variant.)
// Removed the duplicate re-declarations of both overloads — they were
// redundant copies of the declarations immediately above.
template <typename T>
void scale(int num, const T* in, T* out, float scale);

}  // namespace math
}  // namespace cuda
}  // namespace lite
......
......@@ -113,6 +113,8 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
void RuntimeProgram::Run() {
for (auto& inst : instructions_) {
std::string op_type = inst.op()->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
inst.Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.