未验证 提交 4d9ed16a 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #378 from MRXLT/0.2.0-fix-gpu-v2

bug fix
...@@ -53,7 +53,7 @@ def single_func(idx, resource): ...@@ -53,7 +53,7 @@ def single_func(idx, resource):
feed_batch.append(reader.process(dataset[bi])) feed_batch.append(reader.process(dataset[bi]))
b_end = time.time() b_end = time.time()
if profile_flags: if profile_flags:
print("PROFILE\tpid:{}\tbert+pre_0:{} bert_pre_1:{}".format( print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
os.getpid(), os.getpid(),
int(round(b_start * 1000000)), int(round(b_start * 1000000)),
int(round(b_end * 1000000)))) int(round(b_end * 1000000))))
...@@ -69,9 +69,7 @@ def single_func(idx, resource): ...@@ -69,9 +69,7 @@ def single_func(idx, resource):
if __name__ == '__main__': if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner() multi_thread_runner = MultiThreadRunner()
endpoint_list = [ endpoint_list = ["127.0.0.1:9292"]
"127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]
result = multi_thread_runner.run(single_func, args.thread, result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list}) {"endpoint": endpoint_list})
avg_cost = 0 avg_cost = 0
......
...@@ -30,7 +30,10 @@ def predict(image_path, server): ...@@ -30,7 +30,10 @@ def predict(image_path, server):
req = json.dumps({"image": image, "fetch": ["score"]}) req = json.dumps({"image": image, "fetch": ["score"]})
r = requests.post( r = requests.post(
server, data=req, headers={"Content-Type": "application/json"}) server, data=req, headers={"Content-Type": "application/json"})
try:
print(r.json()["score"][0]) print(r.json()["score"][0])
except ValueError:
print(r.text)
return r return r
......
...@@ -32,7 +32,7 @@ def save_model(server_model_folder, ...@@ -32,7 +32,7 @@ def save_model(server_model_folder,
executor = Executor(place=CPUPlace()) executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict] feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
target_vars = fetch_var_dict.values() target_vars = list(fetch_var_dict.values())
save_inference_model( save_inference_model(
server_model_folder, server_model_folder,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册