diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc index 6f4b548155ca91ab01a6426cca6ba92ce4f9340e..3bf970294d8db2cd64e351163d88ac89fb6343d0 100644 --- a/src/io/api_paddle_mobile.cc +++ b/src/io/api_paddle_mobile.cc @@ -71,7 +71,7 @@ bool PaddleMobilePredictor::Run( } auto input = inputs[0]; - if (input.shape.size() != 4) { + if (input.lod.size() == 0 && input.shape.size() != 4) { LOG(kLOG_ERROR) << "input shape not equal to 4!"; return false; } @@ -81,17 +81,29 @@ bool PaddleMobilePredictor::Run( } // use tensor - framework::DDim ddim = - framework::make_ddim({dims[0], dims[1], dims[2], dims[3]}); + framework::DDim ddim = framework::make_ddim(dims); framework::Tensor input_tensor; - input_tensor.Resize(ddim); + framework::LoDTensor input_lod_tensor; + paddle_mobile::framework::LoD lod{{}}; + for (int i = 0; i < input.lod.size(); ++i) { + lod[0].push_back(input.lod[i]); + } + input_lod_tensor.set_lod(lod); + int input_length = framework::product(ddim); - auto input_ptr = input_tensor.mutable_data(); + if (input.lod.size() > 0) { + input_lod_tensor.Resize(ddim); + memcpy(input_lod_tensor.mutable_data(), + static_cast(input.data.data()), input_length * sizeof(T)); + paddle_mobile_->Predict(input_lod_tensor); + } else { + input_tensor.Resize(ddim); + memcpy(input_tensor.mutable_data(), static_cast(input.data.data()), + input_length * sizeof(T)); + paddle_mobile_->Predict(input_tensor); + } - memcpy(input_ptr, static_cast(input.data.data()), - input_length * sizeof(T)); - paddle_mobile_->Predict(input_tensor); auto output_tensor = paddle_mobile_->Fetch(); if (output_data->empty()) { diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h index 9a0ed823b19ad1ec07c2ecef928b1018c56ee62c..ae7d34bd51dd59de9359a471964647c020e18649 100644 --- a/src/io/paddle_inference_api.h +++ b/src/io/paddle_inference_api.h @@ -90,7 +90,7 @@ struct PaddleTensor { PaddleTensor() = default; std::string name; // variable name. std::vector shape; - // TODO(Superjomn) for LoD support, add a vector> field if needed. + std::vector lod; PaddleBuf data; // blob of data. PaddleDType dtype; kTypeId_t dtypeid;