提交 da47bc10 编写于 作者: B barrierye

catch timeout exception when first gpu predict

上级 31177ce7
# gRPC临时接口
# gRPC接口
gRPC 接口实现形式类似 Web Service:
......
......@@ -42,7 +42,8 @@ for ei in range(10000):
fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
prob_list.append(fetch_map['prob'][0][1])
label_list.append(data[0][-1][0])
break
print(prob_list)
print(auc(label_list, prob_list))
end = time.time()
print(end - start)
......@@ -40,9 +40,16 @@ def call_back(call_future, data):
task_count = 0
for data in test_reader():
future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True)
task_count += 1
future.add_done_callback(functools.partial(call_back, data=data))
try:
future = client.predict(
feed={"x": data[0][0]}, fetch=["price"], asyn=True)
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
task_count += 1
future.add_done_callback(functools.partial(call_back, data=data))
while complete_task_count[0] != task_count:
time.sleep(0.1)
......@@ -27,5 +27,11 @@ test_reader = paddle.batch(
for data in test_reader():
batch_feed = [{"x": x[0]} for x in data]
fetch_map = client.predict(feed=batch_feed, fetch=["price"])
print(fetch_map)
try:
fetch_map = client.predict(feed=batch_feed, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print(fetch_map)
......@@ -25,6 +25,12 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
fetch_map = client.predict(
feed={"x": data[0][0]}, fetch=["price"], is_python=False)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
try:
fetch_map = client.predict(
feed={"x": data[0][0]}, fetch=["price"], is_python=False)
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......@@ -25,5 +25,12 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0].tolist()}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
try:
fetch_map = client.predict(
feed={"x": data[0][0].tolist()}, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......@@ -33,5 +33,5 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1])
server.set_gpuid(0)
server.prepare_server(workdir="work_dir1", port=9393, device="gpu")
server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server()
......@@ -25,5 +25,11 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
try:
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
except grpc.RpcError as e:
status_code = e.code()
if grpc.StatusCode.DEADLINE_EXCEEDED == status_code:
print('timeout')
else:
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册