提交 da47bc10 编写于 作者: B barrierye

catch timeout exception when first gpu predict

上级 31177ce7
# gRPC临时接口 # gRPC接口
gRPC 接口实现形式类似 Web Service: gRPC 接口实现形式类似 Web Service:
......
...@@ -42,7 +42,8 @@ for ei in range(10000): ...@@ -42,7 +42,8 @@ for ei in range(10000):
fetch_map = client.predict(feed=feed_dict, fetch=["prob"]) fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
prob_list.append(fetch_map['prob'][0][1]) prob_list.append(fetch_map['prob'][0][1])
label_list.append(data[0][-1][0]) label_list.append(data[0][-1][0])
break
print(prob_list)
print(auc(label_list, prob_list)) print(auc(label_list, prob_list))
end = time.time() end = time.time()
print(end - start) print(end - start)
...@@ -40,9 +40,16 @@ def call_back(call_future, data): ...@@ -40,9 +40,16 @@ def call_back(call_future, data):
task_count = 0 task_count = 0
for data in test_reader(): for data in test_reader():
future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True) try:
task_count += 1 future = client.predict(
future.add_done_callback(functools.partial(call_back, data=data)) feed={"x": data[0][0]}, fetch=["price"], asyn=True)
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
task_count += 1
future.add_done_callback(functools.partial(call_back, data=data))
while complete_task_count[0] != task_count: while complete_task_count[0] != task_count:
time.sleep(0.1) time.sleep(0.1)
...@@ -27,5 +27,11 @@ test_reader = paddle.batch( ...@@ -27,5 +27,11 @@ test_reader = paddle.batch(
for data in test_reader(): for data in test_reader():
batch_feed = [{"x": x[0]} for x in data] batch_feed = [{"x": x[0]} for x in data]
fetch_map = client.predict(feed=batch_feed, fetch=["price"]) try:
print(fetch_map) fetch_map = client.predict(feed=batch_feed, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print(fetch_map)
...@@ -25,6 +25,12 @@ test_reader = paddle.batch( ...@@ -25,6 +25,12 @@ test_reader = paddle.batch(
batch_size=1) batch_size=1)
for data in test_reader(): for data in test_reader():
fetch_map = client.predict( try:
feed={"x": data[0][0]}, fetch=["price"], is_python=False) fetch_map = client.predict(
print("{} {}".format(fetch_map["price"][0], data[0][1][0])) feed={"x": data[0][0]}, fetch=["price"], is_python=False)
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
...@@ -25,5 +25,12 @@ test_reader = paddle.batch( ...@@ -25,5 +25,12 @@ test_reader = paddle.batch(
batch_size=1) batch_size=1)
for data in test_reader(): for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0].tolist()}, fetch=["price"]) try:
print("{} {}".format(fetch_map["price"][0], data[0][1][0])) fetch_map = client.predict(
feed={"x": data[0][0].tolist()}, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
...@@ -33,5 +33,5 @@ server = Server() ...@@ -33,5 +33,5 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1]) server.load_model_config(sys.argv[1])
server.set_gpuid(0) server.set_gpuid(0)
server.prepare_server(workdir="work_dir1", port=9393, device="gpu") server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server() server.run_server()
...@@ -25,5 +25,11 @@ test_reader = paddle.batch( ...@@ -25,5 +25,11 @@ test_reader = paddle.batch(
batch_size=1) batch_size=1)
for data in test_reader(): for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"]) try:
print("{} {}".format(fetch_map["price"][0], data[0][1][0])) fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册