Unverified · Commit 27ec5deb authored by xiaogang, committed by GitHub

fix: fix FPGA execution of the feed/fetch ops (#2868)

fix the FPGA lite_tensor compile bug
     add the fake_quantize_abs_max op
     test=develop
Parent 5f8e3f3e
@@ -121,6 +121,7 @@ void Predictor::SaveOpKernelInfo(const std::string &model_dir) {
              << kpf_path;
 }
+#ifndef LITE_WITH_FPGA
 lite::Tensor *Predictor::GetInput(size_t offset) {
   CHECK(input_names_.size() > offset)
       << "The network has " << input_names_.size() << " inputs"
@@ -130,6 +131,17 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
       << " in exec_scope";
   return in_var->GetMutable<lite::Tensor>();
 }
+#else
+lite::Tensor *Predictor::GetInput(size_t offset) {
+  auto *_feed_list = exec_scope_->FindVar("feed");
+  CHECK(_feed_list) << "no feed variable in exec_scope";
+  auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
+  if (offset >= feed_list->size()) {
+    feed_list->resize(offset + 1);
+  }
+  return &feed_list->at(offset);
+}
+#endif
 // get inputs names
 std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
@@ -167,6 +179,8 @@ void Predictor::PrepareFeedFetch() {
   }
 }
+#ifndef LITE_WITH_FPGA
 const lite::Tensor *Predictor::GetOutput(size_t offset) const {
   CHECK(output_names_.size() > offset)
       << "The network has " << output_names_.size() << " outputs"
@@ -186,6 +200,29 @@ std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
   }
   return outputs;
 }
+#else
+const lite::Tensor *Predictor::GetOutput(size_t offset) const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
+  return &fetch_list.at(offset);
+}
+
+std::vector<const lite::Tensor *> Predictor::GetOutputs() const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  std::vector<const lite::Tensor *> outputs;
+  // Take addresses of the stored elements; a by-value loop variable here
+  // would leave dangling pointers.
+  for (auto &out : fetch_list) {
+    outputs.push_back(&out);
+  }
+  return outputs;
+}
+#endif
 const cpp::ProgramDesc &Predictor::program_desc() const {
   return program_desc_;
......
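For context, a minimal usage sketch of the FPGA feed/fetch path (the shape and the RunOnce helper are illustrative, not part of this commit): under LITE_WITH_FPGA, GetInput and GetOutput are backed by the "feed" and "fetch" std::vector<lite::Tensor> variables in exec_scope_, rather than by input_names_/output_names_.

    #include <algorithm>
    #include <vector>
    #include "lite/api/cxx_api.h"

    void RunOnce(paddle::lite::Predictor &predictor) {
      // GetInput(0) grows the "feed" vector on demand, so slot 0 always exists.
      auto *input = predictor.GetInput(0);
      input->Resize(paddle::lite::DDim(
          std::vector<paddle::lite::DDim::value_type>({1, 3, 224, 224})));
      float *data = input->mutable_data<float>();
      std::fill(data, data + input->dims().production(), 0.f);  // dummy input
      predictor.Run();
      // GetOutput(0) CHECKs that the fetch op has filled the requested slot.
      const auto *output = predictor.GetOutput(0);
      (void)output;
    }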
@@ -31,11 +31,7 @@ TEST(ResNet50, test) {
   std::vector<Place> valid_places(
       {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}});
-  predictor.Build(FLAGS_model_dir,
-                  "",
-                  "",
-                  Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
-                  valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", valid_places);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -151,6 +151,10 @@ class TensorLite {
   size_t offset() const { return offset_; }
   bool IsInitialized() const { return buffer_->data(); }
+  void clear() {
+    buffer_->Free();
+    offset_ = 0;
+  }
   // Other share data to this.
   void ShareDataWith(const TensorLite &other);
......
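A minimal sketch of the new clear() (this assumes Buffer::Free() nulls the data pointer that IsInitialized() tests, and that the class is the FPGA lite_tensor variant named in the commit message):

    paddle::lite::TensorLite t;
    t.Resize(paddle::lite::DDim(
        std::vector<paddle::lite::DDim::value_type>({1, 16})));
    t.mutable_data<float>();   // allocates the underlying buffer
    CHECK(t.IsInitialized());
    t.clear();                 // frees the buffer and resets offset_ to 0
    CHECK(!t.IsInitialized());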
@@ -53,6 +53,11 @@ void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
   }
   valid_kernels_ = op_->CreateKernels(valid_places);
 }
+void mir::Node::Stmt::ResetKernels(const std::vector<Place> &valid_places) {
+  CHECK(op_) << "cannot reset kernels: op has not been created";
+  valid_kernels_.clear();
+  valid_kernels_ = op_->CreateKernels(valid_places);
+}
 mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
   auto &x = AsArg();
......
@@ -53,6 +53,7 @@ class Node {
                  const std::vector<Place>& valid_places,
                  lite::Scope* scope = nullptr);
+    void ResetKernels(const std::vector<Place>& valid_places);
     std::string op_type() const { return op_info()->Type(); }
     const OpInfo* op_info() const;
     OpInfo* mutable_op_info();
......
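A hypothetical sketch of how a pass could use the new ResetKernels (RebindToFpga is illustrative only): unlike ResetOp, it keeps the existing op and its OpDesc untouched and merely regenerates valid_kernels_ for a new set of places.

    #include <vector>
    #include "lite/core/mir/node.h"

    // Illustrative helper: re-pick kernel candidates for an existing op.
    void RebindToFpga(paddle::lite::mir::Node *node) {
      std::vector<paddle::lite::Place> fpga_places{
          paddle::lite::Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)}};
      // The op and its attributes survive; only the kernel list is rebuilt.
      node->AsStmt().ResetKernels(fpga_places);
    }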
@@ -137,11 +137,15 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
 void RuntimeProgram::Run() {
   for (auto& inst : instructions_) {
+#ifndef LITE_WITH_FPGA
     if (inst.is_feed_fetch_op()) continue;
+#endif
     inst.Run();
 #ifdef LITE_WITH_PROFILE
 #ifdef LITE_WITH_PRECISION_PROFILE
+#ifndef LITE_WITH_FPGA
     LITE_PRECISION_PROFILE(inst)
+#endif
 #endif  // LITE_WITH_PRECISION_PROFILE
 #endif  // LITE_WITH_PROFILE
   }
......
@@ -23,3 +23,5 @@ namespace operators {}  // namespace operators
 REGISTER_LITE_OP(fake_quantize_range_abs_max,
                  paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
+REGISTER_LITE_OP(fake_quantize_abs_max,
+                 paddle::lite::operators::FakeQuantizeRangeMaxAbsOpLite);
@@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
   bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
     auto x = op_desc.Input("X").front();
-    auto in_scale = op_desc.Input("InScale").front();
+    if (op_desc.HasInput("InScale")) {
+      auto in_scale = op_desc.Input("InScale").front();
+      param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
+    }
     auto out = op_desc.Output("Out").front();
     auto out_scale = op_desc.Output("OutScale").front();
     param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
-    param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
     param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
     param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
......
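A sketch of why the HasInput("InScale") guard matters (the descriptor below is hand-built for illustration; the include path and attribute name follow the Paddle-Lite sources of this period and may vary): fake_quantize_abs_max computes its scale from the input tensor and carries no InScale operand, whereas fake_quantize_range_abs_max does. Since both ops now attach through FakeQuantizeRangeMaxAbsOpLite, InScale has to be optional.

    #include "lite/model_parser/cpp/op_desc.h"

    paddle::lite::cpp::OpDesc MakeAbsMaxDesc() {
      paddle::lite::cpp::OpDesc desc;
      desc.SetType("fake_quantize_abs_max");
      desc.SetInput("X", {"x"});  // note: no "InScale" input at all
      desc.SetOutput("Out", {"out"});
      desc.SetOutput("OutScale", {"out_scale"});
      desc.SetAttr<int>("bit_length", 8);
      // Without the guard, AttachImpl would call
      // op_desc.Input("InScale").front() on a missing input and abort.
      return desc;
    }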