未验证 提交 a152d43d 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1369 from TeslaZhao/develop

Python pipeline mode supports tensor structure input and output
......@@ -341,7 +341,6 @@ class Client(object):
string_feed_names = []
string_lod_slot_batch = []
string_shape = []
fetch_names = []
for key in fetch_list:
......
......@@ -45,7 +45,9 @@ class ChannelDataErrcode(enum.Enum):
CLOSED_ERROR = 6
NO_SERVICE = 7
UNKNOW = 8
PRODUCT_ERROR = 9
INPUT_PARAMS_ERROR = 9
PRODUCT_ERROR = 100
class ProductErrCode(enum.Enum):
......
......@@ -18,22 +18,110 @@ option go_package = "./;pipeline_serving";
import "google/api/annotations.proto";
// Tensor structure, consistent with PADDLE variable types.
// Describes one input or output tensor.
message Tensor {
// VarType: INT64
repeated int64 int64_data = 1;
// VarType: FP32, FP16
repeated float float_data = 2;
// VarType: INT32, INT16, INT8
repeated int32 int_data = 3;
// VarType: FP64
repeated double float64_data = 4;
// VarType: BF16, UINT8
repeated uint32 uint32_data = 5;
// VarType: BOOL
repeated bool bool_data = 6;
// (Not supported) VarType: COMPLEX64. Element 2x is the real part and
// element 2x+1 is the imaginary part.
repeated float complex64_data = 7;
// (Not supported) VarType: COMPLEX128. Element 2x is the real part and
// element 2x+1 is the imaginary part.
repeated double complex128_data = 8;
// VarType: STRING
repeated string str_data = 9;
// Element types:
// 0 => INT64
// 1 => FP32
// 2 => INT32
// 3 => FP64
// 4 => INT16
// 5 => FP16
// 6 => BF16
// 7 => UINT8
// 8 => INT8
// 9 => BOOL
// 10 => COMPLEX64
// 11 => COMPLEX128
// 12 => STRING
int32 elem_type = 10;
// Shape of the tensor, including batch dimensions.
repeated int32 shape = 11;
// Level of data (LOD); supports variable-length data. Currently only used
// for fetch tensors.
repeated int32 lod = 12;
// Corresponds to the variable 'name' in the model description prototxt.
string name = 13;
};
// The structure of the service request. The input data can be repeated string
// pairs or tensors.
message Request {
// The input data are repeated string pairs.
// For example, key is "words" and value is the string of words.
repeated string key = 1;
repeated string value = 2;
// The input data are repeated tensors for complex data structures.
// Tensors can carry more information about the data and reduce the amount of
// data transferred.
repeated Tensor tensors = 3;
// The name field in the RESTful API
string name = 4;
// The method field in the RESTful API
string method = 5;
// For tracing requests and logs
int64 logid = 6;
// For tracking sources
string clientip = 7;
};
// The structure of the service response. The output data can be repeated string
// pairs or tensors.
message Response {
// Error code
int32 err_no = 1;
// Error messages
string err_msg = 2;
// The results of string pairs
repeated string key = 3;
repeated string value = 4;
};
message Request {
repeated string key = 1;
repeated string value = 2;
string name = 3;
string method = 4;
int64 logid = 5;
string clientip = 6;
// The results of tensors
repeated Tensor tensors = 5;
};
// Python pipeline service
service PipelineService {
rpc inference(Request) returns (Response) {
option (google.api.http) = {
......
......@@ -45,6 +45,23 @@ from .pipeline_client import PipelineClient as PPClient
_LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op")
# Maps the elem_type code of a proto tensor to the matching numpy dtype name.
_TENSOR_DTYPE_2_NUMPY_DATA_DTYPE = {
0: "int64", # VarType.INT64
1: "float32", # VarType.FP32
2: "int32", # VarType.INT32
3: "float64", # VarType.FP64
4: "int16", # VarType.INT16
5: "float16", # VarType.FP16
6: "uint16", # VarType.BF16
7: "uint8", # VarType.UINT8
8: "int8", # VarType.INT8
9: "bool", # VarType.BOOL
10: "complex64", # VarType.COMPLEX64
11: "complex128", # VarType.COMPLEX128
12: "string", # VarType.STRING; has no matching numpy dtype
}
class Op(object):
def __init__(self,
......@@ -85,6 +102,9 @@ class Op(object):
self._server_use_profile = False
self._tracer = None
# for grpc_pipeline predict mode. False, string key/val; True, tensor format.
self._pack_tensor_format = False
# only for thread op
self._for_init_op_lock = threading.Lock()
self._for_close_op_lock = threading.Lock()
......@@ -372,6 +392,9 @@ class Op(object):
os._exit(-1)
self._input_ops.append(op)
# Store the packing format flag that is later passed to the client predict
# call as pack_tensor_format: False packs string key/value pairs, True packs
# tensors.
def set_pack_tensor_format(self, is_tensor_format=False):
self._pack_tensor_format = is_tensor_format
# Return the value stored in self._jump_to_ops.
def get_jump_to_ops(self):
return self._jump_to_ops
......@@ -577,6 +600,7 @@ class Op(object):
feed_dict=feed_batch[0],
fetch=self._fetch_names,
asyn=False,
pack_tensor_format=self._pack_tensor_format,
profile=False)
if call_result is None:
_LOGGER.error(
......@@ -1530,6 +1554,85 @@ class RequestOp(Op):
_LOGGER.critical("Op(Request) Failed to init: {}".format(e))
os._exit(-1)
def proto_tensor_2_numpy(self, tensor):
    """
    Convert one proto tensor to a numpy array.

    Supported elem_type values (proto field read -> numpy dtype):
        0  INT64  int64_data   -> int64
        1  FP32   float_data   -> float32
        2  INT32  int_data     -> int32
        3  FP64   float64_data -> float64
        4  INT16  int_data     -> int16
        5  FP16   float_data   -> float16
        6  BF16   uint32_data  -> uint16
        7  UINT8  uint32_data  -> uint8
        8  INT8   int_data     -> int8
        9  BOOL   bool_data    -> bool

    Unsupported types:
        10 COMPLEX64
        11 COMPLEX128
        12 STRING

    Args:
        tensor: one tensor in request.tensors.

    Returns:
        np.ndarray reshaped to tensor.shape (a flat 1-D array when the shape
        is empty or missing), or None when tensor, tensor.elem_type or
        tensor.name is None.

    Raises:
        ValueError: if tensor.elem_type is not one of the supported types.
    """
    if tensor is None or tensor.elem_type is None or tensor.name is None:
        _LOGGER.error("input params of tensor is wrong. tensor: {}".format(
            tensor))
        return None

    # elem_type -> (proto field holding the values, numpy dtype name).
    # The dtypes are given as strings: bare names such as int64 or float32
    # are not defined in Python and would raise NameError.
    converters = {
        0: ("int64_data", "int64"),
        1: ("float_data", "float32"),
        2: ("int_data", "int32"),
        3: ("float64_data", "float64"),
        4: ("int_data", "int16"),
        5: ("float_data", "float16"),
        6: ("uint32_data", "uint16"),
        7: ("uint32_data", "uint8"),
        8: ("int_data", "int8"),
        9: ("bool_data", "bool"),
    }

    # An unset repeated field is empty rather than None, so test truthiness.
    # Without a shape the data is returned as a flat 1-D array.
    dims = list(tensor.shape) if tensor.shape else [-1]

    _LOGGER.info("proto_to_numpy, name:{}, type:{}, dims:{}".format(
        tensor.name, tensor.elem_type, dims))

    converter = converters.get(tensor.elem_type)
    if converter is None:
        err_msg = "Sorry, the type {} of tensor {} is not supported.".format(
            tensor.elem_type, tensor.name)
        _LOGGER.error(err_msg)
        raise ValueError(err_msg)

    field_name, dtype_name = converter
    values = getattr(tensor, field_name)
    return np.array(values).astype(dtype_name).reshape(dims)
def unpack_request_package(self, request):
"""
Unpack request package by gateway.proto
......@@ -1550,9 +1653,43 @@ class RequestOp(Op):
_LOGGER.critical("request is None")
raise ValueError("request is None")
# unpack key/value string list
for idx, key in enumerate(request.key):
dict_data[key] = request.value[idx]
log_id = request.logid
# unpack proto.tensors data.
for one_tensor in request.tensors:
    name = one_tensor.name
    elem_type = one_tensor.elem_type

    # A protobuf string field defaults to "" rather than None, so an unset
    # name has to be detected by emptiness.
    if not name:
        _LOGGER.error("Tensor name is None.")
        raise ValueError("Tensor name is None.")

    numpy_dtype = _TENSOR_DTYPE_2_NUMPY_DATA_DTYPE.get(elem_type)
    if numpy_dtype is None:
        # The message must be formatted with .format(); passing the value as
        # a separate logging argument never filled the {} placeholder.
        _LOGGER.error(
            "elem_type:{} is dismatch in unpack_request_package.".format(
                elem_type))
        raise ValueError("elem_type:{} error".format(elem_type))

    if numpy_dtype == "string":
        if one_tensor.str_data is None:
            _LOGGER.error(
                "str_data of tensor:{} is None, elem_type is {}.".format(
                    name, elem_type))
            raise ValueError(
                "str_data of tensor:{} is None, elem_type is {}.".format(
                    name, elem_type))
        # All string pieces of the tensor are concatenated into one value.
        dict_data[name] = "".join(one_tensor.str_data)
    else:
        dict_data[name] = self.proto_tensor_2_numpy(one_tensor)
_LOGGER.debug("RequestOp unpack one request. log_id:{}, clientip:{} \
name:{}, method:{}".format(log_id, request.clientip, request.name,
request.method))
......@@ -1574,6 +1711,7 @@ class ResponseOp(Op):
"""
super(ResponseOp, self).__init__(
name="@DAGExecutor", input_ops=input_ops)
# init op
try:
self.init_op()
......@@ -1582,6 +1720,12 @@ class ResponseOp(Op):
e, exc_info=True))
os._exit(-1)
# init ResponseOp
self.is_pack_tensor = False
# Store the isTensor flag in self.is_pack_tensor (False by default).
# NOTE(review): presumably selects tensor vs. string key/value packing of the
# response - confirm against pack_response_package.
def set_pack_format(self, isTensor=False):
self.is_pack_tensor = isTensor
def pack_response_package(self, channeldata):
"""
Getting channeldata from the last channel, packting the response
......
......@@ -46,7 +46,7 @@ class PipelineClient(object):
self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
self._channel)
def _pack_request_package(self, feed_dict, profile):
def _pack_request_package(self, feed_dict, pack_tensor_format, profile):
req = pipeline_service_pb2.Request()
logid = feed_dict.get("logid")
......@@ -69,25 +69,88 @@ class PipelineClient(object):
feed_dict.pop("clientip")
np.set_printoptions(threshold=sys.maxsize)
for key, value in feed_dict.items():
req.key.append(key)
if (sys.version_info.major == 2 and isinstance(value,
(str, unicode)) or
((sys.version_info.major == 3) and isinstance(value, str))):
req.value.append(value)
continue
if isinstance(value, np.ndarray):
req.value.append(value.__repr__())
elif isinstance(value, list):
req.value.append(np.array(value).__repr__())
else:
raise TypeError("only str and np.ndarray type is supported: {}".
format(type(value)))
if profile:
req.key.append(self._profile_key)
req.value.append(self._profile_value)
if pack_tensor_format is False:
# pack string key/val format
for key, value in feed_dict.items():
req.key.append(key)
if (sys.version_info.major == 2 and
isinstance(value, (str, unicode)) or
((sys.version_info.major == 3) and isinstance(value, str))):
req.value.append(value)
continue
if isinstance(value, np.ndarray):
req.value.append(value.__repr__())
elif isinstance(value, list):
req.value.append(np.array(value).__repr__())
else:
raise TypeError(
"only str and np.ndarray type is supported: {}".format(
type(value)))
if profile:
req.key.append(self._profile_key)
req.value.append(self._profile_value)
else:
# pack tensor format
# numpy dtype name -> (Tensor field that carries the values, elem_type code)
tensor_fields = {
    "int64": ("int64_data", 0),
    "float32": ("float_data", 1),
    "int32": ("int_data", 2),
    "float64": ("float64_data", 3),
    "int16": ("int_data", 4),
    "float16": ("float_data", 5),
    "uint16": ("uint32_data", 6),
    "uint8": ("uint32_data", 7),
    "int8": ("int_data", 8),
    "bool": ("bool_data", 9),
}
for key, value in feed_dict.items():
    one_tensor = req.tensors.add()
    one_tensor.name = key

    if (sys.version_info.major == 2 and
            isinstance(value, (str, unicode)) or
        ((sys.version_info.major == 3) and isinstance(value, str))):
        # The proto field is named str_data, and repeated scalar fields are
        # filled with append(); add() only exists on repeated message fields.
        one_tensor.str_data.append(value)
        one_tensor.elem_type = 12  # 12 => string
        continue

    if not isinstance(value, np.ndarray):
        raise TypeError(
            "only str and np.ndarray type is supported: {}".format(
                type(value)))

    # copy shape
    _LOGGER.info("value shape is {}".format(value.shape))
    one_tensor.shape.extend(value.shape)

    # copy data
    field_and_type = tensor_fields.get(value.dtype.name)
    if field_and_type is None:
        # Fail loudly: only logging here would send a tensor that has a
        # shape but no data and the default elem_type 0.
        err_msg = "value type {} of tensor {} is not supported.".format(
            value.dtype, key)
        _LOGGER.error(err_msg)
        raise TypeError(err_msg)
    field_name, elem_type = field_and_type
    getattr(one_tensor, field_name).extend(value.flatten().tolist())
    one_tensor.elem_type = elem_type
return req
def _unpack_response_package(self, resp, fetch):
......@@ -97,6 +160,7 @@ class PipelineClient(object):
feed_dict,
fetch=None,
asyn=False,
pack_tensor_format=False,
profile=False,
log_id=0):
if not isinstance(feed_dict, dict):
......@@ -104,7 +168,8 @@ class PipelineClient(object):
"feed must be dict type with format: {name: value}.")
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict, profile)
req = self._pack_request_package(feed_dict, pack_tensor_format, profile)
req.logid = log_id
if not asyn:
resp = self._stub.inference(req)
......
......@@ -12,25 +12,113 @@
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
syntax = "proto3";
package baidu.paddle_serving.pipeline_serving;
// Tensor structure, consistent with PADDLE variable types.
// Describes one input or output tensor.
message Tensor {
// VarType: INT64
repeated int64 int64_data = 1;
// VarType: FP32, FP16
repeated float float_data = 2;
// VarType: INT32, INT16, INT8
repeated int32 int_data = 3;
// VarType: FP64
repeated double float64_data = 4;
// VarType: BF16, UINT8
repeated uint32 uint32_data = 5;
// VarType: BOOL
repeated bool bool_data = 6;
// (Not supported) VarType: COMPLEX64. Element 2x is the real part and
// element 2x+1 is the imaginary part.
repeated float complex64_data = 7;
// (Not supported) VarType: COMPLEX128. Element 2x is the real part and
// element 2x+1 is the imaginary part.
repeated double complex128_data = 8;
// VarType: STRING
repeated string str_data = 9;
// Element types:
// 0 => INT64
// 1 => FP32
// 2 => INT32
// 3 => FP64
// 4 => INT16
// 5 => FP16
// 6 => BF16
// 7 => UINT8
// 8 => INT8
// 9 => BOOL
// 10 => COMPLEX64
// 11 => COMPLEX128
// 12 => STRING
int32 elem_type = 10;
// Shape of the tensor, including batch dimensions.
repeated int32 shape = 11;
// Level of data (LOD); supports variable-length data. Currently only used
// for fetch tensors.
repeated int32 lod = 12;
// Corresponds to the variable 'name' in the model description prototxt.
string name = 13;
};
// The structure of the service request. The input data can be repeated string
// pairs or tensors.
message Request {
// The input data are repeated string pairs.
// For example, key is "words" and value is the string of words.
repeated string key = 1;
repeated string value = 2;
optional string name = 3;
optional string method = 4;
optional int64 logid = 5;
optional string clientip = 6;
// The input data are repeated tensors for complex data structures.
// Tensors can carry more information about the data and reduce the amount of
// data transferred.
repeated Tensor tensors = 3;
// The name field in the RESTful API
string name = 4;
// The method field in the RESTful API
string method = 5;
// For tracing requests and logs
int64 logid = 6;
// For tracking sources
string clientip = 7;
};
// The structure of the service response. The output data can be repeated string
// pairs or tensors.
message Response {
optional int32 err_no = 1;
optional string err_msg = 2;
// Error code
int32 err_no = 1;
// Error messages
string err_msg = 2;
// The results of string pairs
repeated string key = 3;
repeated string value = 4;
// The results of tensors
repeated Tensor tensors = 5;
};
// Python pipeline service
service PipelineService {
rpc inference(Request) returns (Response) {}
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册