未验证 提交 9ee5d983 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #371 from MRXLT/0.2.0-fix-gpu

fix gpu_ids && benchmark scripts
......@@ -57,7 +57,7 @@ def single_func(idx, resource):
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = client.predict(feed_batch=feed_batch, fetch=fetch)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
......
......@@ -55,7 +55,7 @@ def single_func(idx, resource):
for i in range(1, 27):
feed_dict["sparse_{}".format(i - 1)] = data[0][i]
feed_batch.append(feed_dict)
result = client.predict(feed_batch=feed_batch, fetch=fetch)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
......
......@@ -56,8 +56,7 @@ def single_func(idx, resource):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][
i]
feed_batch.append(feed_dict)
result = client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
......
......@@ -50,7 +50,7 @@ def single_func(idx, resource):
img = reader.process_image(img_list[i])
img = img.reshape(-1)
feed_batch.append({"image": img})
result = client.predict(feed_batch=feed_batch, fetch=fetch)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
......
......@@ -42,8 +42,7 @@ def single_func(idx, resource):
for bi in range(args.batch_size):
word_ids, label = imdb_dataset.get_words_and_label(line)
feed_batch.append({"words": word_ids})
result = client.predict(
feed_batch=feed_batch, fetch=["prediction"])
result = client.predict(feed=feed_batch, fetch=["prediction"])
else:
print("unsupport batch size {}".format(args.batch_size))
......
......@@ -64,14 +64,22 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
def start_multi_card(args): # pylint: disable=doc-string-missing
gpus = ""
if args.gpu_ids == "":
if "CUDA_VISIBLE_DEVICES" in os.environ:
gpus = os.environ["CUDA_VISIBLE_DEVICES"]
else:
gpus = []
gpus = []
else:
gpus = args.gpu_ids.split(",")
if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for ids in gpus:
if int(ids) >= len(env_gpus):
print(
" Max index of gpu_ids out of range, the number of CUDA_VISIBLE_DEVICES is {}.".
format(len(env_gpus)))
exit(-1)
else:
env_gpus = []
if len(gpus) <= 0:
start_gpu_card_model(-1, 0, args)
print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, args)
else:
gpu_processes = []
for i, gpu_id in enumerate(gpus):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册