提交 0e34948a 编写于 作者: B barrierye

fix bug

上级 b0a0d108
...@@ -49,6 +49,8 @@ class ModelRes { ...@@ -49,6 +49,8 @@ class ModelRes {
res._int64_value_map.end()); res._int64_value_map.end());
_float_value_map.insert(res._float_value_map.begin(), _float_value_map.insert(res._float_value_map.begin(),
res._float_value_map.end()); res._float_value_map.end());
_shape_map.insert(res._shape_map.begin(), res._shape_map.end());
_lod_map.insert(res._lod_map.begin(), res._lod_map.end());
} }
ModelRes(ModelRes&& res) { ModelRes(ModelRes&& res) {
_engine_name = std::move(res._engine_name); _engine_name = std::move(res._engine_name);
...@@ -58,6 +60,10 @@ class ModelRes { ...@@ -58,6 +60,10 @@ class ModelRes {
_float_value_map.insert( _float_value_map.insert(
std::make_move_iterator(std::begin(res._float_value_map)), std::make_move_iterator(std::begin(res._float_value_map)),
std::make_move_iterator(std::end(res._float_value_map))); std::make_move_iterator(std::end(res._float_value_map)));
_shape_map.insert(std::make_move_iterator(std::begin(res._shape_map)),
std::make_move_iterator(std::end(res._shape_map)));
_lod_map.insert(std::make_move_iterator(std::begin(res._lod_map)),
std::make_move_iterator(std::end(res._lod_map)));
} }
~ModelRes() {} ~ModelRes() {}
const std::vector<int64_t>& get_int64_by_name(const std::string& name) { const std::vector<int64_t>& get_int64_by_name(const std::string& name) {
...@@ -85,6 +91,10 @@ class ModelRes { ...@@ -85,6 +91,10 @@ class ModelRes {
_float_value_map.insert( _float_value_map.insert(
std::make_move_iterator(std::begin(res._float_value_map)), std::make_move_iterator(std::begin(res._float_value_map)),
std::make_move_iterator(std::end(res._float_value_map))); std::make_move_iterator(std::end(res._float_value_map)));
_shape_map.insert(std::make_move_iterator(std::begin(res._shape_map)),
std::make_move_iterator(std::end(res._shape_map)));
_lod_map.insert(std::make_move_iterator(std::begin(res._lod_map)),
std::make_move_iterator(std::end(res._lod_map)));
} }
return *this; return *this;
} }
......
...@@ -100,11 +100,11 @@ for i in range(3): ...@@ -100,11 +100,11 @@ for i in range(3):
fetch = ["acc", "cost", "prediction"] fetch = ["acc", "cost", "prediction"]
fetch_maps = client.predict(feed=feed, fetch=fetch) fetch_maps = client.predict(feed=feed, fetch=fetch)
if len(fetch_maps) == 1: if len(fetch_maps) == 1:
print("step: {}, res: {}".format(i, fetch_maps['prediction'][1])) print("step: {}, res: {}".format(i, fetch_maps['prediction'][0][1]))
else: else:
for model, fetch_map in fetch_maps.items(): for model, fetch_map in fetch_maps.items():
print("step: {}, model: {}, res: {}".format(i, model, fetch_map[ print("step: {}, model: {}, res: {}".format(i, model, fetch_map[
'prediction'][1])) 'prediction'][0][1]))
``` ```
Compared with the normal prediction service, the client side has not changed much. When multiple model predictions are used, the prediction service will return a dictionary with engine name `engine_name`(the value is defined on the server side) as the key, and the corresponding model prediction results as the value. Compared with the normal prediction service, the client side has not changed much. When multiple model predictions are used, the prediction service will return a dictionary with engine name `engine_name`(the value is defined on the server side) as the key, and the corresponding model prediction results as the value.
......
...@@ -100,11 +100,11 @@ for i in range(3): ...@@ -100,11 +100,11 @@ for i in range(3):
fetch = ["acc", "cost", "prediction"] fetch = ["acc", "cost", "prediction"]
fetch_maps = client.predict(feed=feed, fetch=fetch) fetch_maps = client.predict(feed=feed, fetch=fetch)
if len(fetch_maps) == 1: if len(fetch_maps) == 1:
print("step: {}, res: {}".format(i, fetch_maps['prediction'][1])) print("step: {}, res: {}".format(i, fetch_maps['prediction'][0][1]))
else: else:
for model, fetch_map in fetch_maps.items(): for model, fetch_map in fetch_maps.items():
print("step: {}, model: {}, res: {}".format(i, model, fetch_map[ print("step: {}, model: {}, res: {}".format(i, model, fetch_map[
'prediction'][1])) 'prediction'][0][1]))
``` ```
Client端与普通预测服务没有发生太大的变化。当使用多个模型预测时,预测服务将返回一个key为Server端定义的引擎名称`engine_name`,value为对应的模型预测结果的字典。 Client端与普通预测服务没有发生太大的变化。当使用多个模型预测时,预测服务将返回一个key为Server端定义的引擎名称`engine_name`,value为对应的模型预测结果的字典。
......
...@@ -35,8 +35,8 @@ for i in range(3): ...@@ -35,8 +35,8 @@ for i in range(3):
fetch = ["acc", "cost", "prediction"] fetch = ["acc", "cost", "prediction"]
fetch_maps = client.predict(feed=feed, fetch=fetch) fetch_maps = client.predict(feed=feed, fetch=fetch)
if len(fetch_maps) == 1: if len(fetch_maps) == 1:
print("step: {}, res: {}".format(i, fetch_maps['prediction'][1])) print("step: {}, res: {}".format(i, fetch_maps['prediction'][0][1]))
else: else:
for model, fetch_map in fetch_maps.items(): for model, fetch_map in fetch_maps.items():
print("step: {}, model: {}, res: {}".format(i, model, fetch_map[ print("step: {}, model: {}, res: {}".format(i, model, fetch_map[
'prediction'][1])) 'prediction'][0][1]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册