From 64777291b97f7eff958fcc58c6c2c9ffc0ff7667 Mon Sep 17 00:00:00 2001 From: zp7 <9678873+ForceDaryl@users.noreply.github.com> Date: Fri, 12 Jul 2019 20:04:11 +0800 Subject: [PATCH] paddle_inference_api support lod-tensor input (#1744) --- src/io/api_paddle_mobile.cc | 28 ++++++++++++++++++++-------- src/io/paddle_inference_api.h | 2 +- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc index 6f4b548155..3bf970294d 100644 --- a/src/io/api_paddle_mobile.cc +++ b/src/io/api_paddle_mobile.cc @@ -71,7 +71,7 @@ bool PaddleMobilePredictor::Run( } auto input = inputs[0]; - if (input.shape.size() != 4) { + if (input.lod.size() == 0 && input.shape.size() != 4) { LOG(kLOG_ERROR) << "input shape not equal to 4!"; return false; } @@ -81,17 +81,29 @@ bool PaddleMobilePredictor::Run( } // use tensor - framework::DDim ddim = - framework::make_ddim({dims[0], dims[1], dims[2], dims[3]}); + framework::DDim ddim = framework::make_ddim(dims); framework::Tensor input_tensor; - input_tensor.Resize(ddim); + framework::LoDTensor input_lod_tensor; + paddle_mobile::framework::LoD lod{{}}; + for (int i = 0; i < input.lod.size(); ++i) { + lod[0].push_back(input.lod[i]); + } + input_lod_tensor.set_lod(lod); + int input_length = framework::product(ddim); - auto input_ptr = input_tensor.mutable_data(); + if (input.lod.size() > 0) { + input_lod_tensor.Resize(ddim); + memcpy(input_lod_tensor.mutable_data(), + static_cast(input.data.data()), input_length * sizeof(T)); + paddle_mobile_->Predict(input_lod_tensor); + } else { + input_tensor.Resize(ddim); + memcpy(input_tensor.mutable_data(), static_cast(input.data.data()), + input_length * sizeof(T)); + paddle_mobile_->Predict(input_tensor); + } - memcpy(input_ptr, static_cast(input.data.data()), - input_length * sizeof(T)); - paddle_mobile_->Predict(input_tensor); auto output_tensor = paddle_mobile_->Fetch(); if (output_data->empty()) { diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h index 9a0ed823b1..ae7d34bd51 100644 --- a/src/io/paddle_inference_api.h +++ b/src/io/paddle_inference_api.h @@ -90,7 +90,7 @@ struct PaddleTensor { PaddleTensor() = default; std::string name; // variable name. std::vector shape; - // TODO(Superjomn) for LoD support, add a vector> field if needed. + std::vector lod; PaddleBuf data; // blob of data. PaddleDType dtype; kTypeId_t dtypeid; -- GitLab