提交 c2bd8581 编写于 作者: T TeslaZhao

fix bugs of bert examples for batching feeds

上级 4493dc04
......@@ -33,5 +33,5 @@ for line in sys.stdin:
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape((128, 1))
#print(feed_dict)
result = client.predict(feed=feed_dict, fetch=fetch)
result = client.predict(feed=feed_dict, fetch=fetch, batch=True)
print(result)
......@@ -29,13 +29,14 @@ class BertService(WebService):
def preprocess(self, feed=[], fetch=[]):
feed_res = []
is_batch = True
for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(1, len(feed_dict[key]), 1))
(len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch
return feed_res, fetch, is_batch
bert_service = BertService(name="bert")
......
......@@ -112,13 +112,14 @@ class WebService(object):
if "fetch" not in request.json:
abort(400)
try:
feed, fetch = self.preprocess(request.json["feed"],
request.json["fetch"])
feed, fetch, is_batch = self.preprocess(request.json["feed"],
request.json["fetch"])
if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"]
if len(feed) == 0:
raise ValueError("empty input")
fetch_map = self.client.predict(feed=feed, fetch=fetch, batch=True)
fetch_map = self.client.predict(
feed=feed, fetch=fetch, batch=is_batch)
result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result}
......@@ -188,7 +189,8 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it")
return feed, fetch
is_batch = True
return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None):
print("This API will be deprecated later. Please do not use it")
......
......@@ -167,13 +167,14 @@ class WebService(object):
if "fetch" not in request.json:
abort(400)
try:
feed, fetch = self.preprocess(request.json["feed"],
request.json["fetch"])
feed, fetch, is_batch = self.preprocess(request.json["feed"],
request.json["fetch"])
if isinstance(feed, dict) and "fetch" in feed:
del feed["fetch"]
if len(feed) == 0:
raise ValueError("empty input")
fetch_map = self.client.predict(feed=feed, fetch=fetch)
fetch_map = self.client.predict(
feed=feed, fetch=fetch, batch=is_batch)
result = self.postprocess(
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": result}
......@@ -249,7 +250,8 @@ class WebService(object):
def preprocess(self, feed=[], fetch=[]):
print("This API will be deprecated later. Please do not use it")
return feed, fetch
is_batch = True
return feed, fetch, is_batch
def postprocess(self, feed=[], fetch=[], fetch_map=None):
print("This API will be deprecated later. Please do not use it")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册