提交 e4fb6de7 编写于 作者: T TeslaZhao

Python pipeline mode supports loop OP

上级 e61515bf
......@@ -46,7 +46,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["save_infer_model/scale_0.tmp_1"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"inputs": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["save_infer_model/scale_0.tmp_0"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"inputs": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["save_infer_model/scale_0.tmp_1"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -46,7 +46,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["score"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -49,7 +49,7 @@ class ImagenetOp(Op):
input_imgs = np.concatenate(imgs, axis=0)
return {"image": input_imgs}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
score_list = fetch_dict["prediction"]
result = {"label": [], "prob": []}
for score in score_list:
......
......@@ -19,6 +19,7 @@ import cv2
from paddle_serving_app.reader import *
import base64
class FasterRCNNOp(Op):
def init_op(self):
self.img_preprocess = Sequential([
......@@ -38,22 +39,30 @@ class FasterRCNNOp(Op):
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
im = self.img_preprocess(im)
imgs.append({
"image": im[np.newaxis,:],
"im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis,:],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis,:],
"image": im[np.newaxis, :],
"im_shape":
np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
})
feed_dict = {
"image": np.concatenate([x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate([x["scale_factor"] for x in imgs], axis=0)
"image": np.concatenate(
[x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate(
[x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate(
[x["scale_factor"] for x in imgs], axis=0)
}
#for key in feed_dict.keys():
# print(key, feed_dict[key].shape)
return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
#print(fetch_dict)
res_dict = {"bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))}
res_dict = {
"bbox_result":
str(self.img_postprocess(
fetch_dict, visualize=False))
}
return res_dict, None, ""
......
......@@ -19,6 +19,7 @@ import cv2
from paddle_serving_app.reader import *
import base64
class PPYoloMbvOp(Op):
def init_op(self):
self.img_preprocess = Sequential([
......@@ -38,23 +39,31 @@ class PPYoloMbvOp(Op):
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
im = self.img_preprocess(im)
imgs.append({
"image": im[np.newaxis,:],
"im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis,:],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis,:],
"image": im[np.newaxis, :],
"im_shape":
np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
})
feed_dict = {
"image": np.concatenate([x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate([x["scale_factor"] for x in imgs], axis=0)
"image": np.concatenate(
[x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate(
[x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate(
[x["scale_factor"] for x in imgs], axis=0)
}
for key in feed_dict.keys():
print(key, feed_dict[key].shape)
return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
#print(fetch_dict)
res_dict = {"bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))}
res_dict = {
"bbox_result":
str(self.img_postprocess(
fetch_dict, visualize=False))
}
return res_dict, None, ""
......
......@@ -19,6 +19,7 @@ import cv2
from paddle_serving_app.reader import *
import base64
class Yolov3Op(Op):
def init_op(self):
self.img_preprocess = Sequential([
......@@ -38,22 +39,30 @@ class Yolov3Op(Op):
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
im = self.img_preprocess(im)
imgs.append({
"image": im[np.newaxis,:],
"im_shape": np.array(list(im.shape[1:])).reshape(-1)[np.newaxis,:],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis,:],
"image": im[np.newaxis, :],
"im_shape":
np.array(list(im.shape[1:])).reshape(-1)[np.newaxis, :],
"scale_factor": np.array([1.0, 1.0]).reshape(-1)[np.newaxis, :],
})
feed_dict = {
"image": np.concatenate([x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate([x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate([x["scale_factor"] for x in imgs], axis=0)
"image": np.concatenate(
[x["image"] for x in imgs], axis=0),
"im_shape": np.concatenate(
[x["im_shape"] for x in imgs], axis=0),
"scale_factor": np.concatenate(
[x["scale_factor"] for x in imgs], axis=0)
}
#for key in feed_dict.keys():
# print(key, feed_dict[key].shape)
return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
#print(fetch_dict)
res_dict = {"bbox_result": str(self.img_postprocess(fetch_dict, visualize=False))}
res_dict = {
"bbox_result":
str(self.img_postprocess(
fetch_dict, visualize=False))
}
return res_dict, None, ""
......
......@@ -43,7 +43,7 @@ class BertOp(Op):
print(key, feed_dict[key].shape)
return feed_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
fetch_dict["pooled_output"] = str(fetch_dict["pooled_output"])
return fetch_dict, None, ""
......
......@@ -42,7 +42,7 @@ class ImagenetOp(Op):
img = self.seq(im)
return {"image": img[np.newaxis, :].copy()}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
print(fetch_dict)
score_list = fetch_dict["score"]
result = {"label": [], "prob": []}
......
......@@ -54,7 +54,7 @@ class DetOp(Op):
imgs.append(det_img[np.newaxis, :].copy())
return {"image": np.concatenate(imgs, axis=0)}, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
# print(fetch_dict)
det_out = fetch_dict["concat_1.tmp_0"]
ratio_list = [
......@@ -149,7 +149,7 @@ class RecOp(Op):
return feed_list, False, None, ""
def postprocess(self, input_dicts, fetch_data, log_id):
def postprocess(self, input_dicts, fetch_data, data_id, log_id):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
......
......@@ -40,9 +40,10 @@ class UciOp(Op):
proc_dict = {}
return input_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
_LOGGER.info("UciOp::postprocess >>> log_id:{}, fetch_dict:{}".format(
log_id, fetch_dict))
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
_LOGGER.info(
"UciOp::postprocess >>> data_id:{}, log_id:{}, fetch_dict:{}".
format(data_id, log_id, fetch_dict))
fetch_dict["price"] = str(fetch_dict["price"])
return fetch_dict, None, ""
......
......@@ -41,9 +41,10 @@ class UciOp(Op):
return input_dict, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
_LOGGER.info("UciOp::postprocess >>> log_id:{}, fetch_dict:{}".format(
log_id, fetch_dict))
def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
_LOGGER.info(
"UciOp::postprocess >>> data_id:{}, log_id:{}, fetch_dict:{}".
format(data_id, log_id, fetch_dict))
fetch_dict["price"] = str(fetch_dict["price"][0][0])
return fetch_dict, None, ""
......
......@@ -127,7 +127,8 @@ class LocalPredictor(object):
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
self.fetch_types_[var.alias_name] = var.fetch_type
self.fetch_names_to_type_[var.alias_name] = var.shape
# set precision of inference.
precision_type = paddle_infer.PrecisionType.Float32
......@@ -253,8 +254,27 @@ class LocalPredictor(object):
feed[name] = feed[name].astype("float32")
elif self.feed_types_[name] == 2:
feed[name] = feed[name].astype("int32")
elif self.feed_types_[name] == 3:
feed[name] = feed[name].astype("float64")
elif self.feed_types_[name] == 4:
feed[name] = feed[name].astype("int16")
elif self.feed_types_[name] == 5:
feed[name] = feed[name].astype("float16")
elif self.feed_types_[name] == 6:
feed[name] = feed[name].astype("uint16")
elif self.feed_types_[name] == 7:
feed[name] = feed[name].astype("uint8")
elif self.feed_types_[name] == 8:
feed[name] = feed[name].astype("int8")
elif self.feed_types_[name] == 9:
feed[name] = feed[name].astype("bool")
elif self.feed_types_[name] == 10:
feed[name] = feed[name].astype("complex64")
elif self.feed_types_[name] == 11:
feed[name] = feed[name].astype("complex128")
else:
raise ValueError("local predictor receives wrong data type")
input_tensor_handle = self.predictor.get_input_handle(name)
if "{}.lod".format(name) in feed:
input_tensor_handle.set_lod([feed["{}.lod".format(name)]])
......
......@@ -337,8 +337,6 @@ class Client(object):
string_shape = []
fetch_names = []
counter = 0
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
......
......@@ -31,6 +31,21 @@ import paddle.nn.functional as F
import errno
from paddle.jit import to_static
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
core.VarDesc.VarType.BOOL: 'bool',
core.VarDesc.VarType.FP16: 'float16',
core.VarDesc.VarType.BF16: 'uint16',
core.VarDesc.VarType.FP32: 'float32',
core.VarDesc.VarType.FP64: 'float64',
core.VarDesc.VarType.INT8: 'int8',
core.VarDesc.VarType.INT16: 'int16',
core.VarDesc.VarType.INT32: 'int32',
core.VarDesc.VarType.INT64: 'int64',
core.VarDesc.VarType.UINT8: 'uint8',
core.VarDesc.VarType.COMPLEX64: 'complex64',
core.VarDesc.VarType.COMPLEX128: 'complex128',
}
def save_dygraph_model(serving_model_folder, client_config_folder, model):
paddle.jit.save(model, "serving_tmp")
......@@ -57,13 +72,8 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
feed_var = model_conf.FeedVar()
feed_var.alias_name = key
feed_var.name = feed_var_dict[key].name
feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype)
feed_var.is_lod_tensor = feed_var_dict[key].lod_level >= 1
if feed_var_dict[key].dtype == core.VarDesc.VarType.INT64:
feed_var.feed_type = 0
if feed_var_dict[key].dtype == core.VarDesc.VarType.FP32:
feed_var.feed_type = 1
if feed_var_dict[key].dtype == core.VarDesc.VarType.INT32:
feed_var.feed_type = 2
if feed_var.is_lod_tensor:
feed_var.shape.extend([-1])
else:
......@@ -77,13 +87,8 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
fetch_var = model_conf.FetchVar()
fetch_var.alias_name = key
fetch_var.name = fetch_var_dict[key].name
fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype)
fetch_var.is_lod_tensor = 1
if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT64:
fetch_var.fetch_type = 0
if fetch_var_dict[key].dtype == core.VarDesc.VarType.FP32:
fetch_var.fetch_type = 1
if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT32:
fetch_var.fetch_type = 2
if fetch_var.is_lod_tensor:
fetch_var.shape.extend([-1])
else:
......@@ -119,6 +124,59 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
fout.write(config.SerializeToString())
def var_type_conversion(dtype):
"""
Variable type conversion
Args:
dtype: type of core.VarDesc.VarType.xxxxx
(https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/framework/dtype.py)
Returns:
(int)type value, -1 is type matching failed.
int64 => 0;
float32 => 1;
int32 => 2;
float64 => 3;
int16 => 4;
float16 => 5;
bfloat16 => 6;
uint8 => 7;
int8 => 8;
bool => 9;
complex64 => 10,
complex128 => 11;
"""
type_val = -1
if dtype == core.VarDesc.VarType.INT64:
type_val = 0
elif dtype == core.VarDesc.VarType.FP32:
type_val = 1
elif dtype == core.VarDesc.VarType.INT32:
type_val = 2
elif dtype == core.VarDesc.VarType.FP64:
type_val = 3
elif dtype == core.VarDesc.VarType.INT16:
type_val = 4
elif dtype == core.VarDesc.VarType.FP16:
type_val = 5
elif dtype == core.VarDesc.VarType.BF16:
type_val = 6
elif dtype == core.VarDesc.VarType.UINT8:
type_val = 7
elif dtype == core.VarDesc.VarType.INT8:
type_val = 8
elif dtype == core.VarDesc.VarType.BOOL:
type_val = 9
elif dtype == core.VarDesc.VarType.COMPLEX64:
type_val = 10
elif dtype == core.VarDesc.VarType.COMPLEX128:
type_val = 11
else:
type_val = -1
return type_val
def save_model(server_model_folder,
client_config_folder,
feed_var_dict,
......@@ -164,18 +222,13 @@ def save_model(server_model_folder,
config = model_conf.GeneralModelConfig()
#int64 = 0; float32 = 1; int32 = 2;
for key in feed_var_dict:
feed_var = model_conf.FeedVar()
feed_var.alias_name = key
feed_var.name = feed_var_dict[key].name
feed_var.feed_type = var_type_conversion(feed_var_dict[key].dtype)
feed_var.is_lod_tensor = feed_var_dict[key].lod_level >= 1
if feed_var_dict[key].dtype == core.VarDesc.VarType.INT64:
feed_var.feed_type = 0
if feed_var_dict[key].dtype == core.VarDesc.VarType.FP32:
feed_var.feed_type = 1
if feed_var_dict[key].dtype == core.VarDesc.VarType.INT32:
feed_var.feed_type = 2
if feed_var.is_lod_tensor:
feed_var.shape.extend([-1])
else:
......@@ -190,14 +243,10 @@ def save_model(server_model_folder,
fetch_var = model_conf.FetchVar()
fetch_var.alias_name = key
fetch_var.name = fetch_var_dict[key].name
#fetch_var.is_lod_tensor = fetch_var_dict[key].lod_level >= 1
fetch_var.is_lod_tensor = 1
if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT64:
fetch_var.fetch_type = 0
if fetch_var_dict[key].dtype == core.VarDesc.VarType.FP32:
fetch_var.fetch_type = 1
if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT32:
fetch_var.fetch_type = 2
fetch_var.fetch_type = var_type_conversion(fetch_var_dict[key].dtype)
fetch_var.is_lod_tensor = fetch_var_dict[key].lod_level >= 1
#fetch_var.is_lod_tensor = 1
if fetch_var.is_lod_tensor:
fetch_var.shape.extend([-1])
else:
......
......@@ -101,7 +101,6 @@ def is_gpu_mode(unformatted_gpus):
for ids in op_gpu_list:
if int(ids) >= 0:
return True
return False
......
......@@ -140,7 +140,7 @@ class Server(object):
def set_ir_optimize(self, flag=False):
self.ir_optimization = flag
# Multi-Server does not have this Function.
# Multi-Server does not have this Function.
def set_product_name(self, product_name=None):
if product_name == None:
raise ValueError("product_name can't be None.")
......@@ -437,7 +437,6 @@ class Server(object):
def download_bin(self):
os.chdir(self.module_path)
need_download = False
#acquire lock
version_file = open("{}/version.py".format(self.module_path), "r")
......
......@@ -176,7 +176,7 @@ class DAGExecutor(object):
"in_channel must be Channel type, but get {}".
format(type(in_channel)))
os._exit(-1)
in_channel.add_producer(self.name)
self._in_channel = in_channel
_LOGGER.info("[DAG] set in channel succ, name [{}]".format(self.name))
......@@ -669,14 +669,14 @@ class DAG(object):
out_degree_ops)
dag_views = list(reversed(dag_views))
if not self._build_dag_each_worker:
_LOGGER.debug("================== DAG ====================")
_LOGGER.info("================== DAG ====================")
for idx, view in enumerate(dag_views):
_LOGGER.debug("(VIEW {})".format(idx))
_LOGGER.info("(VIEW {})".format(idx))
for op in view:
_LOGGER.debug(" [{}]".format(op.name))
_LOGGER.info(" [{}]".format(op.name))
for out_op in out_degree_ops[op.name]:
_LOGGER.debug(" - {}".format(out_op.name))
_LOGGER.debug("-------------------------------------------")
_LOGGER.info(" - {}".format(out_op.name))
_LOGGER.info("-------------------------------------------")
# create channels and virtual ops
virtual_op_name_gen = NameGenerator("vir")
......@@ -719,6 +719,7 @@ class DAG(object):
channel = self._gen_channel(channel_name_gen)
channels.append(channel)
op.add_input_channel(channel)
_LOGGER.info("op:{} add input channel.".format(op.name))
pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0:
input_channel = channel
......@@ -726,6 +727,8 @@ class DAG(object):
# if pred_op is virtual op, it will use ancestors as producers to channel
for pred_op in pred_ops:
pred_op.add_output_channel(channel)
_LOGGER.info("pred_op:{} add output channel".format(
pred_op.name))
processed_op.add(op.name)
# find same input op to combine channel
for other_op in actual_next_view[o_idx + 1:]:
......@@ -745,6 +748,7 @@ class DAG(object):
output_channel = self._gen_channel(channel_name_gen)
channels.append(output_channel)
last_op.add_output_channel(output_channel)
_LOGGER.info("last op:{} add output channel".format(last_op.name))
pack_func, unpack_func = None, None
pack_func = response_op.pack_response_package
......@@ -752,7 +756,11 @@ class DAG(object):
actual_ops = virtual_ops
for op in used_ops:
if len(op.get_input_ops()) == 0:
#set special features of the request op.
#1.set unpack function.
#2.set output channel.
unpack_func = op.unpack_request_package
op.add_output_channel(input_channel)
continue
actual_ops.append(op)
......
......@@ -58,13 +58,15 @@ class Op(object):
retry=0,
batch_size=None,
auto_batching_timeout=None,
local_service_handler=None):
local_service_handler=None,
jump_to_ops=[]):
# In __init__, all the parameters are just saved and Op is not initialized
if name is None:
name = _op_name_gen.next()
self.name = name # to identify the type of OP, it must be globally unique
self.concurrency = concurrency # amount of concurrency
self.set_input_ops(input_ops)
self.set_jump_to_ops(jump_to_ops)
self._local_service_handler = local_service_handler
self._server_endpoints = server_endpoints
......@@ -99,9 +101,7 @@ class Op(object):
conf: config.yaml
Returns:
None
"""
# init op
if self.concurrency is None:
self.concurrency = conf["concurrency"]
if self._retry is None:
......@@ -372,6 +372,79 @@ class Op(object):
os._exit(-1)
self._input_ops.append(op)
def get_jump_to_ops(self):
return self._jump_to_ops
def set_jump_to_ops(self, ops):
"""
Set jump to ops, then, this op can send channeldata to output channel.
Args:
ops: op list to be jumpped
Returns:
None.
"""
if not isinstance(ops, list):
ops = [] if ops is None else [ops]
self._jump_to_ops = []
for op in ops:
if not isinstance(op, Op):
_LOGGER.critical(
self._log("Failed to set input_ops: input op "
"must be Op type, not {}".format(type(op))))
os._exit(-1)
self._jump_to_ops.append(op)
def is_jump_op(self):
"""
The op has _jump_to_ops members or not.
Args:
None
Returns:
True or False
"""
return len(self._jump_to_ops) > 0
def check_jumping(self, input_data):
"""
Check whether to send data to jump ops.WhileOp needs to rewrite
this interface. this function returns False default.
Args:
input_data: input data to be preprocessed
Returns:
True, send data to the output channel of jump ops
False, send data to output channel.
"""
return False
def get_output_channels_of_jump_ops(self):
"""
Get output channels of jump ops
Args:
None
Returns:
list of channels
"""
channels = []
if self.is_jump_op() is False:
return channels
for op in self._jump_to_ops:
_LOGGER.info("op:{} extend op._get_output_channels:{}".format(
op.name, op._get_output_channels()))
channels.extend(op._get_output_channels())
_LOGGER.info("get_output_channels_of_jump_ops, channels:{}".format(
channels))
return channels
def add_input_channel(self, channel):
"""
Adding one input channel to the Op. Each op have many front op,
......@@ -410,6 +483,7 @@ class Op(object):
os._exit(-1)
channel.add_producer(self.name)
self._outputs.append(channel)
_LOGGER.info("op:{} add output_channel {}".format(self.name, channel))
def clean_output_channels(self):
self._outputs = []
......@@ -424,7 +498,7 @@ class Op(object):
Args:
input_dicts: input data to be preprocessed
data_id: inner unique id, 0 default
data_id: inner unique id, increase auto
log_id: global unique id for RTT, 0 default
Return:
......@@ -484,12 +558,13 @@ class Op(object):
'''
return call_result
def postprocess(self, input_data, fetch_data, log_id=0):
def postprocess(self, input_data, fetch_data, data_id=0, log_id=0):
"""
In postprocess stage, assemble data for next op or output.
Args:
input_data: data returned in preprocess stage, dict(for single predict) or list(for batch predict)
fetch_data: data returned in process stage, dict(for single predict) or list(for batch predict)
data_id: inner unique id, increase auto
log_id: logid, 0 default
Returns:
......@@ -593,7 +668,8 @@ class Op(object):
self.device_type, self.devices, self.mem_optim,
self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list))
self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops()))
p.daemon = True
p.start()
process.append(p)
......@@ -629,7 +705,8 @@ class Op(object):
self.device_type, self.devices, self.mem_optim,
self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list))
self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops()))
# When a process exits, it attempts to terminate
# all of its daemonic child processes.
t.daemon = True
......@@ -954,7 +1031,7 @@ class Op(object):
prod_errcode, prod_errinfo = None, None
try:
postped_data, prod_errcode, prod_errinfo = self.postprocess(
parsed_data_dict[data_id], midped_data,
parsed_data_dict[data_id], midped_data, data_id,
logid_dict.get(data_id))
except Exception as e:
error_info = "(data_id={} log_id={}) {} Failed to postprocess: {}".format(
......@@ -1100,7 +1177,8 @@ class Op(object):
def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list):
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
is_jump_op, output_channels_of_jump_ops):
"""
_run() is the entry function of OP process / thread model.When client
type is local_predictor in process mode, the CUDA environment needs to
......@@ -1127,6 +1205,8 @@ class Op(object):
mkldnn_cache_capacity: cache capacity of mkldnn, 0 means no limit.
mkldnn_op_list: OP list optimized by mkldnn, None default.
mkldnn_bf16_op_list: OP list optimized by mkldnn bf16, None default.
is_jump_op: OP has jump op list or not, False default.
output_channels_of_jump_ops: all output channels of jump ops.
Returns:
None
......@@ -1267,27 +1347,46 @@ class Op(object):
break
if len(postped_data_dict) == 0:
continue
# push data to channel (if run succ)
start = int(round(_time() * 1000000))
try:
profile_str = profiler.gen_profile_str()
for data_id, postped_data in postped_data_dict.items():
if self._server_use_profile:
sys.stderr.write(profile_str)
self._push_to_output_channels(
data=postped_data,
channels=output_channels,
profile_str=profile_str,
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
after_outchannel_time = _time()
_LOGGER.debug(
"(data_id={}) PUSH OUTPUT CHANNEL! op:{} push cost:{} ms".
format(data_id, self.name, (after_outchannel_time -
after_postp_time) * 1000))
_LOGGER.debug(
"(data_id={}) PUSH OUTPUT CHANNEL! op:{} push data:{}".
format(data_id, self.name, postped_data.get_all_data()))
if self.is_jump_op() is True and self.check_jumping(
postped_data_dict) is True:
# push data to output channel of ops to be jumped
for data_id, postped_data in postped_data_dict.items():
if self._server_use_profile:
sys.stderr.write(profile_str)
self._push_to_output_channels(
data=postped_data,
channels=output_channels_of_jump_ops,
profile_str=profile_str,
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
after_outchannel_time = _time()
_LOGGER.debug(
"(data_id={}) PUSH OUTPUT CHANNEL OF JUMP OPs! op:{} push cost:{} ms".
format(data_id, self.name, (after_outchannel_time -
after_postp_time) *
1000))
else:
# push data to output channel.
for data_id, postped_data in postped_data_dict.items():
if self._server_use_profile:
sys.stderr.write(profile_str)
self._push_to_output_channels(
data=postped_data,
channels=output_channels,
profile_str=profile_str,
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
after_outchannel_time = _time()
_LOGGER.debug(
"(data_id={}) PUSH OUTPUT CHANNEL! op:{} push cost:{} ms".
format(data_id, self.name, (after_outchannel_time -
after_postp_time) *
1000))
except ChannelStopError:
_LOGGER.debug("{} Stop.".format(op_info_prefix))
self._finalize(is_thread_op)
......@@ -1410,7 +1509,7 @@ class RequestOp(Op):
for idx, key in enumerate(request.key):
dict_data[key] = request.value[idx]
log_id = request.logid
_LOGGER.info("RequestOp unpack one request. log_id:{}, clientip:{} \
_LOGGER.debug("RequestOp unpack one request. log_id:{}, clientip:{} \
name:{}, method:{}".format(log_id, request.clientip, request.name,
request.method))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册