提交 27a4de2f 编写于 作者: G guru4elephant

fix result map bug

上级 92c21d6e
......@@ -77,19 +77,6 @@ PYBIND11_MODULE(serving_client, m) {
fetch_name,
predict_res);
})
.def("predict",
[](PredictorClient &self,
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
return self.predict(float_feed,
float_feed_name,
int_feed,
int_feed_name,
fetch_name);
})
.def("batch_predict",
[](PredictorClient &self,
const std::vector<std::vector<std::vector<float>>>
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!flask/bin/python
from plugin_service import PluginService
import sys
class IMDBService(PluginService):
def prepare_service(self, args={}):
if len(args) == 0:
exit(-1)
self.word_dict = {}
with open(args["dict_file_path"]) as fin:
idx = 0
for line in fin:
self.word_dict[idx] = idx
idx += 1
def preprocess(self, feed={}, fetch=[]):
if "words" not in feed:
exit(-1)
res_feed = {}
res_feed["words"] = [self.word_dict[int(x)] for x in feed["words"]]
print(res_feed)
return res_feed, fetch
imdb_service = IMDBService(name="imdb", model=sys.argv[1], port=9898)
imdb_service.prepare_service({"dict_file_path":sys.argv[2]})
imdb_service.start_service()
......@@ -127,6 +127,7 @@ class Client(object):
predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(
sdk_desc.SerializeToString())
......@@ -161,6 +162,7 @@ class Client(object):
float_slot, float_feed_names, int_slot,
int_feed_names, fetch_names, self.result_handle_)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[0]
......@@ -213,3 +215,4 @@ class Client(object):
def release(self):
self.client_handle_.destroy_predictor()
self.client_handle_ = None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册