Commit 987ab664 authored by HexToString

update http_client

Parent ff034d3f
@@ -22,12 +22,8 @@ message EngineDesc {
required string reloadable_type = 4;
required string model_dir = 5;
repeated int32 gpu_ids = 6;
optional int32 runtime_thread_num = 7 [ default = 0 ];
optional int32 batch_infer_size = 8 [ default = 32 ];
optional bool enable_overrun = 9 [ default = false ];
optional bool allow_split_request = 10 [ default = true ];
optional string version_file = 11;
optional string version_type = 12;
optional string version_file = 7;
optional string version_type = 8;
/*
* Sparse Parameter Service type. Valid types are:
@@ -40,17 +36,34 @@ message EngineDesc {
LOCAL = 1;
REMOTE = 2;
}
optional SparseParamServiceType sparse_param_service_type = 13;
optional string sparse_param_service_table_name = 14;
optional bool enable_memory_optimization = 15;
optional bool enable_ir_optimization = 16;
optional bool use_trt = 17;
optional bool use_lite = 18;
optional bool use_xpu = 19;
optional bool use_gpu = 20;
optional bool combined_model = 21;
optional bool encrypted_model = 22;
optional bool gpu_multi_stream = 23;
optional SparseParamServiceType sparse_param_service_type = 11;
optional string sparse_param_service_table_name = 12;
optional bool enable_memory_optimization = 13;
optional bool enable_ir_optimization = 14;
optional bool use_trt = 15;
optional bool use_lite = 16;
optional bool use_xpu = 17;
optional bool use_gpu = 18;
optional bool combined_model = 19;
optional bool encrypted_model = 20;
optional bool gpu_multi_stream = 21;
/*
* "runtime_thread_num": n == 0 means don`t use Asynchronous task scheduling
* mode.
* n > 0 means how many Predictor for this engine in Asynchronous task
* scheduling mode.
* "batch_infer_size": the max batch for this engine in Asynchronous task
* scheduling mode.
* "enable_overrun": always put a whole task into the TaskQueue even if the
* total batch is bigger than "batch_infer_size".
* "allow_split_request": allow to split task(which is corresponding to
* request).
*/
optional int32 runtime_thread_num = 30 [ default = 0 ];
optional int32 batch_infer_size = 31 [ default = 32 ];
optional bool enable_overrun = 32 [ default = false ];
optional bool allow_split_request = 33 [ default = true ];
};
// model_toolkit conf
......
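For context, here is a minimal sketch of how an engine entry in a model_toolkit prototxt might enable the asynchronous task scheduling mode described by the comment above. The async field names and defaults come from the EngineDesc message in this diff; the enclosing engines field name, the engine name, type, and paths are assumptions, not values from this commit.

engines {
  name: "general_infer_0"                  # placeholder engine name
  type: "PADDLE_INFER"                     # placeholder engine type
  reloadable_meta: "uci_housing_model/fluid_time_file"   # placeholder reload meta file
  reloadable_type: "timestamp_ne"          # placeholder reload check type
  model_dir: "uci_housing_model"           # placeholder model path
  runtime_thread_num: 4      # > 0: run 4 Predictors in asynchronous scheduling mode
  batch_infer_size: 32       # max batch assembled for one inference call
  enable_overrun: false      # default: do not enqueue a task whole if it exceeds batch_infer_size
  allow_split_request: true  # default: a task (one request) may be split to fill batches
}

Per the comment in the message, leaving runtime_thread_num at its default of 0 keeps the engine in the ordinary synchronous mode; the other three fields only apply to the asynchronous mode.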
@@ -324,7 +324,8 @@ bool TaskExecutor<TaskT>::move_task_to_batch(
}
if (rem <= 0) break;
}
LOG(INFO) << "Number of tasks remaining in _task_queue is"
<< _task_queue.size();
return true;
}
......
@@ -289,6 +289,7 @@ class Client(object):
log_id=0):
self.profile_.record('py_prepro_0')
# fetch may be empty (None); in that case all model outputs are returned
if feed is None:
raise ValueError("You should specify feed for prediction")
@@ -297,6 +298,7 @@ class Client(object):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
# fetch may be empty (None); in that case all model outputs are returned
elif fetch is None:
pass
else:
@@ -442,6 +444,7 @@ class Client(object):
model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names):
result_map = {}
# fetch is empty, so all output tensor names are used
if len(fetch_names) == 0:
fetch_names = result_batch_handle.get_tensor_alias_names(mi)
# result map needs to be a numpy array
......
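The comments above document that fetch may be omitted, in which case every output of the model is returned. Below is a minimal usage sketch of that behavior; the client config path, server address, and feed variable name are placeholders, and it assumes fetch defaults to None in this version of the client.

# Hypothetical usage sketch; paths, endpoint, and variable names are placeholders.
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9393"])

feed_data = {"x": [0.0137, -0.1136, 0.2553, -0.0692,
                   0.0582, -0.0727, -0.1583, -0.0584,
                   0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}

# fetch is omitted (None), so the result map contains all output tensors.
result_all = client.predict(feed=feed_data)

# Equivalent call that names the desired outputs explicitly.
result_price = client.predict(feed=feed_data, fetch=["price"])

The second call is the conventional form; the first relies on the fall-through branch in this diff (fetch is None), with output selection then deferred to the tensor alias names reported by the server.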