提交 0adeee11 编写于 作者: L Liufang Sang 提交者: whs

[PaddleSlim] change normal program to compile program to get infer time (#3712)

上级 4a1803d1
...@@ -168,21 +168,23 @@ def main(): ...@@ -168,21 +168,23 @@ def main():
imid2path = reader.imid2path imid2path = reader.imid2path
keys = ['bbox'] keys = ['bbox']
infer_time = True infer_time = True
compile_prog = fluid.compiler.CompiledProgram(infer_prog)
for iter_id, data in enumerate(reader()): for iter_id, data in enumerate(reader()):
feed_data = [[d[0], d[1]] for d in data] feed_data = [[d[0], d[1]] for d in data]
# for infer time # for infer time
if infer_time: if infer_time:
warmup_times = 10 warmup_times = 10
repeats_time = 30 repeats_time = 100
feed_data_dict = feeder.feed(feed_data); feed_data_dict = feeder.feed(feed_data);
for i in range(warmup_times): for i in range(warmup_times):
exe.run(infer_prog, exe.run(compile_prog,
feed=feed_data_dict, feed=feed_data_dict,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
start_time = time.time() start_time = time.time()
for i in range(repeats_time): for i in range(repeats_time):
exe.run(infer_prog, exe.run(compile_prog,
feed=feed_data_dict, feed=feed_data_dict,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
...@@ -190,7 +192,7 @@ def main(): ...@@ -190,7 +192,7 @@ def main():
print("infer time: {} ms/sample".format((time.time()-start_time) * 1000 / repeats_time)) print("infer time: {} ms/sample".format((time.time()-start_time) * 1000 / repeats_time))
infer_time = False infer_time = False
outs = exe.run(infer_prog, outs = exe.run(compile_prog,
feed=feeder.feed(feed_data), feed=feeder.feed(feed_data),
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
......
...@@ -50,25 +50,26 @@ def infer(args): ...@@ -50,25 +50,26 @@ def infer(args):
results=[] results=[]
#for infer time, if you don't need, please change infer_time to False #for infer time, if you don't need, please change infer_time to False
infer_time = True infer_time = True
compile_prog = fluid.compiler.CompiledProgram(test_program)
for batch_id, data in enumerate(test_reader()): for batch_id, data in enumerate(test_reader()):
# for infer time # for infer time
if infer_time: if infer_time:
warmup_times = 10 warmup_times = 10
repeats_time = 30 repeats_time = 100
feed_data = feeder.feed(data) feed_data = feeder.feed(data)
for i in range(warmup_times): for i in range(warmup_times):
exe.run(test_program, exe.run(compile_prog,
feed=feed_data, feed=feed_data,
fetch_list=fetch_targets) fetch_list=fetch_targets)
start_time = time.time() start_time = time.time()
for i in range(repeats_time): for i in range(repeats_time):
exe.run(test_program, exe.run(compile_prog,
feed=feed_data, feed=feed_data,
fetch_list=fetch_targets) fetch_list=fetch_targets)
print("infer time: {} ms/sample".format((time.time()-start_time) * 1000 / repeats_time)) print("infer time: {} ms/sample".format((time.time()-start_time) * 1000 / repeats_time))
infer_time = False infer_time = False
# top1_acc, top5_acc # top1_acc, top5_acc
result = exe.run(test_program, result = exe.run(compile_prog,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=fetch_targets) fetch_list=fetch_targets)
result = np.array(result[0]) result = np.array(result[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册