Commit 8f5c838b authored by TeslaZhao

Add bytes type in tensor proto. When data is large, pack/unpack performance is better than repeated value.
Parent 6477fe08
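For context on the performance claim: with repeated scalar fields, every element of a large ndarray is copied into a Python list and then into the message field one value at a time, whereas the new bytes field serializes the whole buffer with a single np.save call and restores it with a single np.load. A minimal round-trip sketch (not part of the commit; the array shape is a placeholder):

    import numpy as np
    from io import BytesIO

    value = np.random.rand(1, 3, 224, 224).astype("float32")  # placeholder payload

    # Repeated-value path: every element becomes a Python object before it is
    # appended to the repeated scalar field of the Tensor message.
    flat_value = value.flatten().tolist()

    # Bytes path added by this commit: np.save writes dtype, shape and the raw
    # buffer into one blob that goes into Tensor.byte_data (elem_type 13).
    np_bytes = BytesIO()
    np.save(np_bytes, value, allow_pickle=True)
    byte_data = np_bytes.getvalue()

    # The receiver restores the array with a single np.load.
    restored = np.load(BytesIO(byte_data), allow_pickle=True)
    assert restored.dtype == value.dtype and restored.shape == value.shape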
@@ -349,7 +349,7 @@ T* VersionedInferEngine::get_core() {
 }
 template <typename T>
-T* VersionedInferEngine::get_core(uint64_t version) {
+T* VersionedInferEngine::get_core(const uint64_t version) {
   auto iter = _versions.find(version);
   if (iter == _versions.end()) {
     LOG(ERROR) << "Not found version engine: " << version;
@@ -539,7 +539,7 @@ int InferManager::infer(const char* model_name,
 }
 template <typename T>
-T* InferManager::get_core(const char* model_name, uint64_t version) {
+T* InferManager::get_core(const char* model_name, const uint64_t version) {
   auto it = _map.find(model_name);
   if (it == _map.end()) {
     LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
...
@@ -277,7 +277,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
     LOG(WARNING) << "Loading cube cache[" << next_idx << "] ...";
     std::string model_path = conf.model_dir();
     if (access(model_path.c_str(), F_OK) == 0) {
-      std::string cube_cache_path = model_path + "cube_cache";
+      std::string cube_cache_path = model_path + "/" + "cube_cache";
       int reload_cache_ret = md->caches[next_idx]->reload_data(cube_cache_path);
       LOG(WARNING) << "Loading cube cache[" << next_idx << "] done.";
     } else {
@@ -543,9 +543,9 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
         lod_tensor_in->CopyFromCpu(data);
       } else {
         LOG(ERROR) << "Inference not support type["
-                   << (*tensorVector_in_pointer)[i].dtype
-                   << "],name[" << (*tensorVector_in_pointer)[i].name
-                   << "]" << " copy into core failed!";
+                   << (*tensorVector_in_pointer)[i].dtype << "],name["
+                   << (*tensorVector_in_pointer)[i].name << "]"
+                   << " copy into core failed!";
       }
       // Paddle inference will support FP16 in next version.
       // else if ((*tensorVector_in_pointer)[i].dtype ==
@@ -724,7 +724,7 @@ class VersionedInferEngine : public InferEngine {
   int infer(const void* in, void* out, uint32_t batch_size, uint64_t version);
   template <typename T>
-  T* get_core(uint64_t version);
+  T* get_core(const uint64_t version);
   int proc_initialize_impl(const configure::EngineDesc& conf, bool);
@@ -789,7 +789,7 @@ class InferManager {
   // Versioned get engine core
   template <typename T>
-  T* get_core(const char* model_name, uint64_t version);
+  T* get_core(const char* model_name, const uint64_t version);
   // query model version
   int query_version(const std::string& model, uint64_t& version);
...
@@ -51,6 +51,12 @@ message Tensor {
   // VarType: STRING
   repeated string str_data = 9;
+  // VarType: BYTES, is suitable for big data. No need to save data types and
+  // dimensions
+  // pack method: pack by BytesIO, saved by np.save
+  // unpack method: load by np.load, unpack by BytesIO.
+  bytes byte_data = 10;
   // Element types:
   // 0 => INT64
   // 1 => FP32
@@ -65,17 +71,18 @@ message Tensor {
   // 10 => COMPLEX64
   // 11 => COMPLEX128
   // 12 => STRING
-  int32 elem_type = 10;
+  // 13 => BYTES
+  int32 elem_type = 20;
   // Shape of the tensor, including batch dimensions.
-  repeated int32 shape = 11;
+  repeated int32 shape = 21;
   // Level of data(LOD), support variable length data, only for fetch tensor
   // currently.
-  repeated int32 lod = 12;
+  repeated int32 lod = 22;
   // Correspond to the variable 'name' in the model description prototxt.
-  string name = 13;
+  string name = 23;
 };
 // The structure of the service request. The input data can be repeated string
...
@@ -26,6 +26,7 @@ import collections
 import numpy as np
 import json
 from numpy import *
+from io import BytesIO
 if sys.version_info.major == 2:
     import Queue
 elif sys.version_info.major == 3:
@@ -59,7 +60,8 @@ _TENSOR_DTYPE_2_NUMPY_DATA_DTYPE = {
     9: "bool",  # VarType.BOOL
     10: "complex64",  # VarType.COMPLEX64
     11: "complex128",  # VarType.COMPLEX128
-    12: "string",  # dismatch with numpy
+    12: "string",  # load by numpy
+    13: "bytes",  # load by numpy
 }
@@ -1577,10 +1579,11 @@ class RequestOp(Op):
             UINT8
             INT8
             BOOL
+            BYTES
         Unsupported type:
-            STRING
             COMPLEX64
             COMPLEX128
+            STRING
         Args:
             tensor: one tensor in request.tensors.
@@ -1634,6 +1637,10 @@ class RequestOp(Op):
         elif tensor.elem_type == 9:
             # VarType: BOOL
             np_data = np.array(tensor.bool_data).astype(bool).reshape(dims)
+        elif tensor.elem_type == 13:
+            # VarType: BYTES
+            byte_data = BytesIO(tensor.byte_data)
+            np_data = np.load(byte_data, allow_pickle=True)
         else:
             _LOGGER.error("Sorry, the type {} of tensor {} is not supported.".
                           format(tensor.elem_type, tensor.name))
...
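A self-contained check of the new elem_type == 13 branch above (a sketch only; SimpleNamespace stands in for the real Tensor protobuf message):

    import numpy as np
    from io import BytesIO
    from types import SimpleNamespace

    original = np.arange(12, dtype="float32").reshape(3, 4)
    buf = BytesIO()
    np.save(buf, original, allow_pickle=True)
    tensor = SimpleNamespace(name="x", elem_type=13, byte_data=buf.getvalue())

    # Mirrors the branch added to RequestOp: unpack byte_data with np.load.
    if tensor.elem_type == 13:
        np_data = np.load(BytesIO(tensor.byte_data), allow_pickle=True)
    assert (np_data == original).all()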
@@ -25,6 +25,7 @@ from .channel import ChannelDataErrcode
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc
 import six
+from io import BytesIO
 _LOGGER = logging.getLogger(__name__)
@@ -47,7 +48,8 @@ class PipelineClient(object):
         self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
             self._channel)
-    def _pack_request_package(self, feed_dict, pack_tensor_format, profile):
+    def _pack_request_package(self, feed_dict, pack_tensor_format,
+                              use_tensor_bytes, profile):
         req = pipeline_service_pb2.Request()
         logid = feed_dict.get("logid")
@@ -99,11 +101,9 @@ class PipelineClient(object):
                 one_tensor = req.tensors.add()
                 one_tensor.name = key
-                if (sys.version_info.major == 2 and
-                        isinstance(value, (str, unicode)) or
-                        ((sys.version_info.major == 3) and isinstance(value, str))):
+                if isinstance(value, str):
                     one_tensor.string_data.add(value)
-                    one_tensor.elem_type = 12  #12 => string
+                    one_tensor.elem_type = 12  #12 => string in proto
                     continue
                 if isinstance(value, np.ndarray):
@@ -112,6 +112,13 @@ class PipelineClient(object):
                     for one_dim in value.shape:
                         one_tensor.shape.append(one_dim)
+                    # packed into bytes
+                    if use_tensor_bytes is True:
+                        np_bytes = BytesIO()
+                        np.save(np_bytes, value, allow_pickle=True)
+                        one_tensor.byte_data = np_bytes.getvalue()
+                        one_tensor.elem_type = 13  #13 => bytes in proto
                     flat_value = value.flatten().tolist()
                     # copy data
                     if value.dtype == "int64":
@@ -162,6 +169,7 @@ class PipelineClient(object):
                 fetch=None,
                 asyn=False,
                 pack_tensor_format=False,
+                use_tensor_bytes=False,
                 profile=False,
                 log_id=0):
         if not isinstance(feed_dict, dict):
@@ -170,7 +178,8 @@ class PipelineClient(object):
         if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
         print("PipelineClient::predict pack_data time:{}".format(time.time()))
-        req = self._pack_request_package(feed_dict, pack_tensor_format, profile)
+        req = self._pack_request_package(feed_dict, pack_tensor_format,
+                                         use_tensor_bytes, profile)
         req.logid = log_id
         if not asyn:
             print("PipelineClient::predict before time:{}".format(time.time()))
...
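For completeness, a hedged usage sketch of the new client flag (endpoint, feed key and fetch name are placeholders, and the import path may differ between Serving packages):

    import numpy as np
    from paddle_serving_server.pipeline import PipelineClient  # import path is an assumption

    client = PipelineClient()
    client.connect(["127.0.0.1:9993"])  # placeholder endpoint
    feed = {"image": np.random.rand(1, 3, 224, 224).astype("float32")}
    # pack_tensor_format=True selects the Tensor protocol; use_tensor_bytes=True
    # additionally packs each ndarray into Tensor.byte_data via np.save.
    ret = client.predict(feed_dict=feed, fetch=["prediction"],
                         pack_tensor_format=True, use_tensor_bytes=True)
    print(ret)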
@@ -48,6 +48,12 @@ message Tensor {
   // VarType: STRING
   repeated string str_data = 9;
+  // VarType: BYTES, is suitable for big data. No need to save data types and
+  // dimensions
+  // pack method: pack by BytesIO, saved by np.save
+  // unpack method: load by np.load, unpack by BytesIO.
+  bytes byte_data = 10;
   // Element types:
   // 0 => INT64
   // 1 => FP32
@@ -62,17 +68,18 @@ message Tensor {
   // 10 => COMPLEX64
   // 11 => COMPLEX128
   // 12 => STRING
-  int32 elem_type = 10;
+  // 13 => BYTES
+  int32 elem_type = 20;
   // Shape of the tensor, including batch dimensions.
-  repeated int32 shape = 11;
+  repeated int32 shape = 21;
   // Level of data(LOD), support variable length data, only for fetch tensor
   // currently.
-  repeated int32 lod = 12;
+  repeated int32 lod = 22;
   // Correspond to the variable 'name' in the model description prototxt.
-  string name = 13;
+  string name = 23;
 };
 // The structure of the service request. The input data can be repeated string
...