提交 64777291 编写于 作者: Z zp7 提交者: GitHub

paddle_inference_api support lod-tensor input (#1744)

上级 f6ac7dd2
...@@ -71,7 +71,7 @@ bool PaddleMobilePredictor<Device, T>::Run( ...@@ -71,7 +71,7 @@ bool PaddleMobilePredictor<Device, T>::Run(
} }
auto input = inputs[0]; auto input = inputs[0];
if (input.shape.size() != 4) { if (input.lod.size() == 0 && input.shape.size() != 4) {
LOG(kLOG_ERROR) << "input shape not equal to 4!"; LOG(kLOG_ERROR) << "input shape not equal to 4!";
return false; return false;
} }
...@@ -81,17 +81,29 @@ bool PaddleMobilePredictor<Device, T>::Run( ...@@ -81,17 +81,29 @@ bool PaddleMobilePredictor<Device, T>::Run(
} }
// use tensor // use tensor
framework::DDim ddim = framework::DDim ddim = framework::make_ddim(dims);
framework::make_ddim({dims[0], dims[1], dims[2], dims[3]});
framework::Tensor input_tensor; framework::Tensor input_tensor;
input_tensor.Resize(ddim); framework::LoDTensor input_lod_tensor;
paddle_mobile::framework::LoD lod{{}};
for (int i = 0; i < input.lod.size(); ++i) {
lod[0].push_back(input.lod[i]);
}
input_lod_tensor.set_lod(lod);
int input_length = framework::product(ddim); int input_length = framework::product(ddim);
auto input_ptr = input_tensor.mutable_data<T>(); if (input.lod.size() > 0) {
input_lod_tensor.Resize(ddim);
memcpy(input_lod_tensor.mutable_data<T>(),
static_cast<T *>(input.data.data()), input_length * sizeof(T));
paddle_mobile_->Predict(input_lod_tensor);
} else {
input_tensor.Resize(ddim);
memcpy(input_tensor.mutable_data<T>(), static_cast<T *>(input.data.data()),
input_length * sizeof(T));
paddle_mobile_->Predict(input_tensor);
}
memcpy(input_ptr, static_cast<T *>(input.data.data()),
input_length * sizeof(T));
paddle_mobile_->Predict(input_tensor);
auto output_tensor = paddle_mobile_->Fetch(); auto output_tensor = paddle_mobile_->Fetch();
if (output_data->empty()) { if (output_data->empty()) {
......
...@@ -90,7 +90,7 @@ struct PaddleTensor { ...@@ -90,7 +90,7 @@ struct PaddleTensor {
PaddleTensor() = default; PaddleTensor() = default;
std::string name; // variable name. std::string name; // variable name.
std::vector<int> shape; std::vector<int> shape;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed. std::vector<int> lod;
PaddleBuf data; // blob of data. PaddleBuf data; // blob of data.
PaddleDType dtype; PaddleDType dtype;
kTypeId_t dtypeid; kTypeId_t dtypeid;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册