提交 b5359c74 编写于 作者: N nhzlx

use Precision<P>::ptype instead of float

上级 221ce639
...@@ -70,10 +70,11 @@ bool PaddleMobilePredictor<Dtype, P>::Run( ...@@ -70,10 +70,11 @@ bool PaddleMobilePredictor<Dtype, P>::Run(
framework::Tensor input_tensor; framework::Tensor input_tensor;
input_tensor.Resize(ddim); input_tensor.Resize(ddim);
int input_length = framework::product(ddim); int input_length = framework::product(ddim);
auto input_ptr = input_tensor.mutable_data<float>(); typedef typename PrecisionTrait<P>::ptype PType;
auto input_ptr = input_tensor.mutable_data<PType>();
memcpy(input_ptr, static_cast<float *>(input.data.data()), memcpy(input_ptr, static_cast<PType *>(input.data.data()),
input_length * sizeof(float)); input_length * sizeof(PType));
auto output_tensor = paddle_mobile_->Predict(input_tensor); auto output_tensor = paddle_mobile_->Predict(input_tensor);
if (output_data->empty()) { if (output_data->empty()) {
...@@ -90,12 +91,12 @@ bool PaddleMobilePredictor<Dtype, P>::Run( ...@@ -90,12 +91,12 @@ bool PaddleMobilePredictor<Dtype, P>::Run(
output.shape.push_back(static_cast<int>(d)); output.shape.push_back(static_cast<int>(d));
} }
if (output.data.length() < output_length * sizeof(float)) { if (output.data.length() < output_length * sizeof(PType)) {
output.data.Resize(output_length * sizeof(float)); output.data.Resize(output_length * sizeof(PType));
} }
memcpy(output.data.data(), output_tensor->template data<float>(), memcpy(output.data.data(), output_tensor->template data<PType>(),
output_length * sizeof(float)); output_length * sizeof(PType));
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册