Commit 58e085f9 authored by MRXLT

fix demo for py3

Parent 604ae6c5
@@ -17,10 +17,16 @@
 import base64
 import json
 import time
 import os
+import sys
 
+py_version = sys.version_info[0]
+
 def predict(image_path, server):
-    image = base64.b64encode(open(image_path).read())
+    if py_version == 2:
+        image = base64.b64encode(open(image_path).read())
+    else:
+        image = base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
     req = json.dumps({"image": image, "fetch": ["score"]})
     r = requests.post(
         server, data=req, headers={"Content-Type": "application/json"})
@@ -28,15 +34,6 @@ def predict(image_path, server):
     return r
 
 
-def batch_predict(image_path, server):
-    image = base64.b64encode(open(image_path).read())
-    req = json.dumps({"image": [image, image], "fetch": ["score"]})
-    r = requests.post(
-        server, data=req, headers={"Content-Type": "application/json"})
-    print(r.json()["result"][1]["score"][0])
-    return r
-
-
 if __name__ == "__main__":
     server = "http://127.0.0.1:9393/image/prediction"
     image_list = os.listdir("./image_data/n01440764/")
......
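The version branch above can also be collapsed into a single code path, since opening the image in binary mode and decoding the base64 output behaves the same on Python 2 and Python 3. A minimal sketch, assuming the same "image"/"fetch" request format as the demo; the encode_image helper name is made up here for illustration:

import base64
import json

import requests


def encode_image(image_path):
    # "rb" works on both Python 2 and 3; decoding the base64 bytes
    # yields a plain text string that json.dumps can serialize.
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def predict(image_path, server):
    req = json.dumps({"image": encode_image(image_path), "fetch": ["score"]})
    return requests.post(
        server, data=req, headers={"Content-Type": "application/json"})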
@@ -19,16 +19,15 @@ import time
 
 client = Client()
 client.load_client_config(sys.argv[1])
-client.connect(["127.0.0.1:9295"])
+client.connect(["127.0.0.1:9393"])
 
 reader = ImageReader()
 start = time.time()
 for i in range(1000):
-    with open("./data/n01440764_10026.JPEG") as f:
+    with open("./data/n01440764_10026.JPEG", "rb") as f:
         img = f.read()
     img = reader.process_image(img).reshape(-1)
     fetch_map = client.predict(feed={"image": img}, fetch=["score"])
-    print(i)
 
 end = time.time()
 print(end - start)
......
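For reference, the timing loop above can be wrapped into a small helper that reports average latency per request instead of the raw total. A rough sketch that reuses the demo's Client and ImageReader objects; the benchmark function itself is hypothetical, not part of the demo:

import time


def benchmark(client, reader, image_path, iterations=1000):
    with open(image_path, "rb") as f:
        raw = f.read()
    start = time.time()
    for _ in range(iterations):
        # Mirror the demo: preprocess and predict on every iteration.
        img = reader.process_image(raw).reshape(-1)
        client.predict(feed={"image": img}, fetch=["score"])
    elapsed = time.time() - start
    # Average time per request, in milliseconds.
    return elapsed / iterations * 1000.0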
@@ -19,15 +19,23 @@ import paddle
 import re
 import paddle.fluid.incubate.data_generator as dg
 
+py_version = sys.version_info[0]
+
 
 class IMDBDataset(dg.MultiSlotDataGenerator):
     def load_resource(self, dictfile):
         self._vocab = {}
         wid = 0
-        with open(dictfile) as f:
-            for line in f:
-                self._vocab[line.strip()] = wid
-                wid += 1
+        if py_version == 2:
+            with open(dictfile) as f:
+                for line in f:
+                    self._vocab[line.strip()] = wid
+                    wid += 1
+        else:
+            with open(dictfile, encoding="utf-8") as f:
+                for line in f:
+                    self._vocab[line.strip()] = wid
+                    wid += 1
         self._unk_id = len(self._vocab)
         self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
         self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])
......
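Here too the version branch can be avoided: io.open accepts an encoding argument on both Python 2 and 3, so the vocabulary can be loaded through a single code path. A minimal sketch; the load_vocab helper is illustrative and not part of the IMDBDataset class:

import io


def load_vocab(dictfile):
    vocab = {}
    with io.open(dictfile, encoding="utf-8") as f:
        # Assign consecutive ids in file order, matching the loop above.
        for wid, line in enumerate(f):
            vocab[line.strip()] = wid
    return vocab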