Unverified commit 961ad4cf, authored by zp7, committed by GitHub

paddle_inference_api support lod-tensor input (#1744)

Parent 1a49587d
@@ -71,7 +71,7 @@ bool PaddleMobilePredictor<Device, T>::Run(
   }
   auto input = inputs[0];
-  if (input.shape.size() != 4) {
+  if (input.lod.size() == 0 && input.shape.size() != 4) {
     LOG(kLOG_ERROR) << "input shape not equal to 4!";
     return false;
   }
@@ -81,17 +81,29 @@ bool PaddleMobilePredictor<Device, T>::Run(
   }
   // use tensor
-  framework::DDim ddim =
-      framework::make_ddim({dims[0], dims[1], dims[2], dims[3]});
+  framework::DDim ddim = framework::make_ddim(dims);
   framework::Tensor input_tensor;
-  input_tensor.Resize(ddim);
+  framework::LoDTensor input_lod_tensor;
+  paddle_mobile::framework::LoD lod{{}};
+  for (int i = 0; i < input.lod.size(); ++i) {
+    lod[0].push_back(input.lod[i]);
+  }
+  input_lod_tensor.set_lod(lod);
   int input_length = framework::product(ddim);
-  auto input_ptr = input_tensor.mutable_data<T>();
+  if (input.lod.size() > 0) {
+    input_lod_tensor.Resize(ddim);
+    memcpy(input_lod_tensor.mutable_data<T>(),
+           static_cast<T *>(input.data.data()), input_length * sizeof(T));
+    paddle_mobile_->Predict(input_lod_tensor);
+  } else {
+    input_tensor.Resize(ddim);
+    memcpy(input_tensor.mutable_data<T>(), static_cast<T *>(input.data.data()),
+           input_length * sizeof(T));
+    paddle_mobile_->Predict(input_tensor);
+  }
-  memcpy(input_ptr, static_cast<T *>(input.data.data()),
-         input_length * sizeof(T));
-  paddle_mobile_->Predict(input_tensor);
   auto output_tensor = paddle_mobile_->Fetch();
   if (output_data->empty()) {
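
For context, input.lod in this API is a flat list of level-0 offsets, and the loop added above simply wraps it into a one-level LoD before calling Predict. Below is a minimal standalone sketch of that mapping; the offset values are hypothetical and plain standard-library types stand in for paddle_mobile::framework::LoD, so this is illustrative only, not code from the commit.

#include <cstddef>
#include <vector>

int main() {
  // Hypothetical batch packing two sequences: rows [0, 3) and [3, 7)
  // of the flattened input tensor.
  std::vector<int> flat_lod = {0, 3, 7};

  // Stand-in for paddle_mobile::framework::LoD: one offset vector per level.
  std::vector<std::vector<size_t>> lod{{}};
  for (size_t i = 0; i < flat_lod.size(); ++i) {
    lod[0].push_back(static_cast<size_t>(flat_lod[i]));
  }
  // lod is now {{0, 3, 7}}: a single level whose offsets mark the
  // boundaries between sequences in the batch.
  return 0;
}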
@@ -90,7 +90,7 @@ struct PaddleTensor {
   PaddleTensor() = default;
   std::string name;  // variable name.
   std::vector<int> shape;
-  // TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
+  std::vector<int> lod;
   PaddleBuf data;  // blob of data.
   PaddleDType dtype;
   kTypeId_t dtypeid;
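
With the new lod field on PaddleTensor, a caller can feed a sequence input through the C++ API roughly as sketched below. This is only an illustration: the input name, shape, offsets, data values, and header path are assumptions, predictor creation is assumed to happen elsewhere, and only the PaddleTensor fields shown in the hunk above come from this commit.

#include <cstring>
#include <vector>

#include "io/paddle_inference_api.h"  // assumed header path; may differ by build setup

// Sketch: run one LoD (sequence) input through an already-created predictor.
bool RunLoDExample(PaddlePredictor *predictor) {
  // Seven feature rows of width 8, packed as two sequences of lengths 3 and 4.
  std::vector<float> buf(7 * 8, 0.f);

  std::vector<PaddleTensor> inputs(1), outputs;
  PaddleTensor &t = inputs[0];
  t.name = "words";                 // hypothetical input variable name
  t.shape = {7, 8};                 // with lod set, the shape need not be 4-D
  t.lod = {0, 3, 7};                // level-0 offsets: sequences [0, 3) and [3, 7)
  t.dtype = PaddleDType::FLOAT32;
  // PaddleBuf is assumed to own its buffer and expose Resize()/data() as in
  // the upstream Paddle inference API.
  t.data.Resize(buf.size() * sizeof(float));
  std::memcpy(t.data.data(), buf.data(), buf.size() * sizeof(float));

  return predictor->Run(inputs, &outputs);
}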