Commit 9b5c51d5 authored by ShiningZhang

client support fp16

Parent cecd2c12
@@ -23,8 +23,7 @@ using configure::GeneralModelConfig;
 using baidu::paddle_serving::predictor::general_model::Request;
 using baidu::paddle_serving::predictor::general_model::Response;
 using baidu::paddle_serving::predictor::general_model::Tensor;
-// paddle inference 2.1 support: FLOAT32, INT64, INT32, UINT8, INT8
-// will support: FLOAT16
+// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
 enum ProtoDataType {
   P_INT64 = 0,
   P_FLOAT32,
@@ -431,7 +430,8 @@ int PredictorOutputs::ParseProto(const Response& res,
                 output.tensor(idx).int_data().begin(),
                 output.tensor(idx).int_data().begin() + size);
       } else if (fetch_name_to_type[name] == P_UINT8
-                 || fetch_name_to_type[name] == P_INT8) {
+                 || fetch_name_to_type[name] == P_INT8
+                 || fetch_name_to_type[name] == P_FP16) {
         VLOG(2) << "fetch var [" << name << "]type="
                 << fetch_name_to_type[name];
         string_data_map[name] = output.tensor(idx).tensor_content();
......
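With this change, FP16 outputs travel the same way as UINT8/INT8: the raw bytes are packed into the tensor's tensor_content field and decoded on the client side. Below is a minimal sketch of that decode step with numpy; the variable names are illustrative, and np.frombuffer is the non-deprecated equivalent of the np.fromstring call used in the Python client further down.

import numpy as np

# Illustrative stand-in for a tensor_content payload: raw fp16 bytes.
raw = np.array([0.5, 1.25, -2.0], dtype=np.float16).tobytes()

# Decode without copying, then restore the shape reported by the server.
arr = np.frombuffer(raw, dtype=np.float16).reshape(1, 3)
print(arr)  # [[ 0.5   1.25 -2.  ]]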
@@ -25,8 +25,7 @@ using baidu::paddle_serving::Timer;
 using baidu::paddle_serving::predictor::general_model::Request;
 using baidu::paddle_serving::predictor::general_model::Response;
 using baidu::paddle_serving::predictor::general_model::Tensor;
-// paddle inference support: FLOAT32, INT64, INT32, UINT8, INT8
-// will support: FLOAT16
+// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
 enum ProtoDataType {
   P_INT64 = 0,
   P_FLOAT32,
......
@@ -31,8 +31,7 @@ using baidu::paddle_serving::predictor::MempoolWrapper;
 using baidu::paddle_serving::predictor::general_model::Tensor;
 using baidu::paddle_serving::predictor::general_model::Request;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
-// paddle inference 2.1 support: FLOAT32, INT64, INT32, UINT8, INT8
-// will support: FLOAT16
+// support: FLOAT32, INT64, INT32, UINT8, INT8, FLOAT16
 enum ProtoDataType {
   P_INT64 = 0,
   P_FLOAT32,
......
@@ -551,6 +551,22 @@ class Client(object):
                         tmp_lod = result_batch_handle.get_lod(mi, name)
                         if np.size(tmp_lod) > 0:
                             result_map["{}.lod".format(name)] = tmp_lod
+                elif self.fetch_names_to_type_[name] == float16_type:
+                    # result_map[name] will be py::array (numpy array)
+                    tmp_str = result_batch_handle.get_string_by_name(
+                        mi, name)
+                    result_map[name] = np.fromstring(tmp_str, dtype=np.float16)
+                    if result_map[name].size == 0:
+                        raise ValueError(
+                            "Failed to fetch, maybe the type of [{}]"
+                            " is wrong, please check the model file".format(
+                                name))
+                    shape = result_batch_handle.get_shape(mi, name)
+                    result_map[name].shape = shape
+                    if name in self.lod_tensor_set:
+                        tmp_lod = result_batch_handle.get_lod(mi, name)
+                        if np.size(tmp_lod) > 0:
+                            result_map["{}.lod".format(name)] = tmp_lod
             multi_result_map.append(result_map)
         ret = None
         if len(model_engine_names) == 1:
......
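After this commit, the Python client can fetch float16 outputs directly. A hypothetical end-to-end call, assuming a served model whose fetch variable "score" is declared fp16 in its client config; the feed/fetch names and port below are placeholders, not taken from this commit:

import numpy as np
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")
client.connect(["127.0.0.1:9393"])

# "x" and "score" are placeholder feed/fetch names.
x = np.random.rand(1, 13).astype("float32")
fetch_map = client.predict(feed={"x": x}, fetch=["score"], batch=True)
print(fetch_map["score"].dtype)  # float16, decoded via the new branch above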