Commit cecd2c12 authored by ShiningZhang

server support fp16

Parent cc096ef0
...
@@ -130,11 +130,11 @@ int GeneralReaderOp::inference() {
       data_len = tensor.tensor_content().size();
       src_ptr = tensor.tensor_content().data();
     } else if (elem_type == P_FP16) {
-      // paddle inference will support FLOAT16
-      // elem_size = 1;
-      // paddleTensor.dtype = paddle::PaddleDType::FLOAT16;
-      // data_len = tensor.tensor_content().size();
-      // src_ptr = tensor.tensor_content().data();
+      // copy bytes from tensor content to TensorVector
+      elem_size = 1;
+      paddleTensor.dtype = paddle::PaddleDType::FLOAT16;
+      data_len = tensor.tensor_content().size();
+      src_ptr = tensor.tensor_content().data();
     } else if (elem_type == P_STRING) {
       // use paddle::PaddleDType::UINT8 as for String.
       elem_size = sizeof(char);
...
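Because `elem_size` is 1 and `data_len` comes straight from `tensor_content().size()`, the reader treats FP16 input as an opaque byte blob whose length is already the byte count; the server never widens or rescales elements. A minimal client-side sketch of producing that byte layout (hypothetical helper names, truncating float-to-half conversion, little-endian host assumed — not part of the Serving client API):

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical helper: convert one float32 to IEEE fp16 bits by
// truncating the mantissa (no rounding, denormals flushed to zero,
// NaN not handled) -- just enough to illustrate the wire layout.
static uint16_t FloatToHalfBits(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  uint32_t sign = (x >> 16) & 0x8000u;
  int32_t exp = static_cast<int32_t>((x >> 23) & 0xFFu) - 127 + 15;
  uint32_t mant = x & 0x7FFFFFu;
  if (exp >= 31) return static_cast<uint16_t>(sign | 0x7C00u);  // +/-inf
  if (exp <= 0) return static_cast<uint16_t>(sign);             // flush to 0
  return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) |
                               (mant >> 13));
}

// Pack floats as consecutive little-endian fp16 bytes: the blob the
// server copies verbatim when elem_type is P_FP16.
std::string PackFp16(const std::vector<float>& src) {
  std::string bytes(src.size() * sizeof(uint16_t), '\0');
  auto* dst = reinterpret_cast<uint16_t*>(&bytes[0]);
  for (size_t i = 0; i < src.size(); ++i) dst[i] = FloatToHalfBits(src[i]);
  return bytes;
}
```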
...
@@ -178,14 +178,12 @@ int GeneralResponseOp::inference() {
         VLOG(2) << "(logid=" << log_id << ")Prepare int8 var ["
                 << model_config->_fetch_name[idx] << "].";
         tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
-      }
-      // inference will support fp16
-      // else if (dtype == paddle::PaddleDType::FLOAT16) {
-      //   tensor->set_elem_type(5);
-      //   VLOG(2) << "(logid=" << log_id << ")Prepare float16 var ["
-      //           << model_config->_fetch_name[idx] << "].";
-      //   tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
-      // }
+      } else if (dtype == paddle::PaddleDType::FLOAT16) {
+        tensor->set_elem_type(5);
+        VLOG(2) << "(logid=" << log_id << ")Prepare float16 var ["
+                << model_config->_fetch_name[idx] << "].";
+        tensor->set_tensor_content(in->at(idx).data.data(), in->at(idx).data.length());
+      }
       VLOG(2) << "(logid=" << log_id << ") fetch var ["
               << model_config->_fetch_name[idx] << "] ready";
...
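On the fetch side the response carries the same raw layout (elem_type 5, FLOAT16), so a client has to widen the bytes back to float32 itself. A companion sketch (hypothetical helper names, normals only, denormals flushed for brevity):

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical helper: expand IEEE fp16 bits to float32
// (denormals flushed to signed zero for brevity).
static float HalfBitsToFloat(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t x;
  if (exp == 0) {
    x = sign;                                       // zero / flushed denormal
  } else if (exp == 31) {
    x = sign | 0x7F800000u | (mant << 13);          // inf / NaN
  } else {
    x = sign | ((exp + 112u) << 23) | (mant << 13); // rebias: 127 - 15 = 112
  }
  float f;
  std::memcpy(&f, &x, sizeof(f));
  return f;
}

// Decode the fp16 tensor_content blob returned for elem_type 5.
std::vector<float> UnpackFp16(const std::string& bytes) {
  const auto* src = reinterpret_cast<const uint16_t*>(bytes.data());
  std::vector<float> out(bytes.size() / sizeof(uint16_t));
  for (size_t i = 0; i < out.size(); ++i) out[i] = HalfBitsToFloat(src[i]);
  return out;
}
```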
...
@@ -31,6 +31,7 @@
 #include "core/predictor/framework/infer_data.h"
 #include "core/predictor/framework/memory.h"
 #include "paddle_inference_api.h"  // NOLINT
+#include "experimental/float16.h"
 namespace baidu {
 namespace paddle_serving {
 namespace predictor {
...
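The new include pulls in `paddle::platform::float16`, the host-side half type the engine code below casts through. A tiny sketch of the conversions it is used for here (assuming, as the diff implies, that the header ships with the Paddle Inference package and the type converts to and from `float`):

```cpp
#include "experimental/float16.h"  // paddle::platform::float16

// Round-trip through the half type: construction narrows float32 to
// fp16 (with precision loss), the cast widens it back.
void Fp16RoundTrip() {
  paddle::platform::float16 h(3.14159f);
  float back = static_cast<float>(h);  // ~3.1406 after fp16 rounding
  (void)back;
}
```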
...
@@ -541,19 +542,17 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
                      paddle::PaddleDType::INT8) {
         int8_t* data = static_cast<int8_t*>(origin_data);
         lod_tensor_in->CopyFromCpu(data);
+      } else if ((*tensorVector_in_pointer)[i].dtype ==
+                 paddle::PaddleDType::FLOAT16) {
+        paddle::platform::float16* data =
+            static_cast<paddle::platform::float16*>(origin_data);
+        lod_tensor_in->CopyFromCpu(data);
       } else {
         LOG(ERROR) << "Inference not support type["
                    << (*tensorVector_in_pointer)[i].dtype << "],name["
                    << (*tensorVector_in_pointer)[i].name << "]"
                    << " copy into core failed!";
       }
-      // Paddle inference will support FP16 in next version.
-      // else if ((*tensorVector_in_pointer)[i].dtype ==
-      //     paddle::PaddleDType::FLOAT16) {
-      //   paddle::platform::float16* data =
-      //       static_cast<paddle::platform::float16*>(origin_data);
-      //   lod_tensor_in->CopyFromCpu(data);
-      // }
       VLOG(2) << "Tensor:name=" << (*tensorVector_in_pointer)[i].name
               << ";in_dtype=" << (*tensorVector_in_pointer)[i].dtype
               << ";tensor_dtype=" << lod_tensor_in->type();
...
...
@@ -641,20 +640,18 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
         int8_t* data_out = reinterpret_cast<int8_t*>(databuf_data);
         lod_tensor_out->CopyToCpu(data_out);
         databuf_char = reinterpret_cast<char*>(data_out);
+      } else if (dataType == paddle::PaddleDType::FLOAT16) {
+        databuf_size = out_num * sizeof(paddle::platform::float16);
+        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
+        if (!databuf_data) {
+          LOG(ERROR) << "Malloc failed, size: " << databuf_size;
+          return -1;
+        }
+        paddle::platform::float16* data_out =
+            reinterpret_cast<paddle::platform::float16*>(databuf_data);
+        lod_tensor_out->CopyToCpu(data_out);
+        databuf_char = reinterpret_cast<char*>(data_out);
       }
-      // Inference will support FP16 in next version
-      // else if (dataType == paddle::PaddleDType::FLOAT16) {
-      //   using float16 = paddle::platform::float16;
-      //   databuf_size = out_num * sizeof(float16);
-      //   databuf_data = MempoolWrapper::instance().malloc(databuf_size);
-      //   if (!databuf_data) {
-      //     LOG(ERROR) << "Malloc failed, size: " << databuf_size;
-      //     return -1;
-      //   }
-      //   float16* data_out = reinterpret_cast<float16*>(databuf_data);
-      //   lod_tensor_out->CopyToCpu(data_out);
-      //   databuf_char = reinterpret_cast<char*>(data_out);
-      // }
       // Because task scheduling requires OPs to use 'Channel'
       // (which is a data structure) to transfer data between OPs.
...
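The output path has a matching shape: size the destination buffer from the tensor dims, exactly as `databuf_size = out_num * sizeof(paddle::platform::float16)` does above, then `CopyToCpu` into it. A hedged companion sketch for the paddle_infer API (`predictor` as in the input sketch):

```cpp
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "experimental/float16.h"
#include "paddle_inference_api.h"

// Fetch an fp16 output, sizing the host buffer from the tensor shape.
std::vector<paddle::platform::float16> FetchFp16(
    const std::shared_ptr<paddle_infer::Predictor>& predictor) {
  auto out_name = predictor->GetOutputNames()[0];
  auto out_tensor = predictor->GetOutputHandle(out_name);
  std::vector<int> shape = out_tensor->shape();
  int out_num = std::accumulate(shape.begin(), shape.end(), 1,
                                std::multiplies<int>());
  std::vector<paddle::platform::float16> out(out_num);
  out_tensor->CopyToCpu(out.data());  // same call the diff enables
  return out;
}
```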