提交 11ed966b 编写于 作者: H hysunflower 提交者: Jinhua Liang

add_for_transformer_model (#3980)

上级 25571110
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import utils.dist_utils as dist_utils import utils.dist_utils as dist_utils
from utils.input_field import InputField from utils.input_field import InputField
...@@ -250,12 +251,15 @@ def do_train(args): ...@@ -250,12 +251,15 @@ def do_train(args):
# start training # start training
step_idx = 0 step_idx = 0
total_batch_num = 0 # this is for benchmark
for pass_id in range(args.epoch): for pass_id in range(args.epoch):
pass_start_time = time.time() pass_start_time = time.time()
input_field.loader.start() input_field.loader.start()
batch_id = 0 batch_id = 0
while True: while True:
if args.max_iter and total_batch_num == args.max_iter: # this for benchmark
return
try: try:
outs = exe.run(compiled_train_prog, outs = exe.run(compiled_train_prog,
fetch_list=[sum_cost.name, token_num.name] fetch_list=[sum_cost.name, token_num.name]
...@@ -299,6 +303,14 @@ def do_train(args): ...@@ -299,6 +303,14 @@ def do_train(args):
batch_id += 1 batch_id += 1
step_idx += 1 step_idx += 1
total_batch_num = total_batch_num + 1 # this is for benchmark
# profiler tools for benchmark
if args.is_profiler and pass_id == 0 and batch_id == args.print_step:
profiler.start_profiler("All")
elif args.is_profiler and pass_id == 0 and batch_id == args.print_step + 5:
profiler.stop_profiler("total", args.profiler_path)
return
except fluid.core.EOFException: except fluid.core.EOFException:
input_field.loader.reset() input_field.loader.reset()
......
...@@ -198,6 +198,11 @@ class PDConfig(object): ...@@ -198,6 +198,11 @@ class PDConfig(object):
self.default_g.add_arg("do_save_inference_model", bool, False, self.default_g.add_arg("do_save_inference_model", bool, False,
"Whether to perform model saving for inference.") "Whether to perform model saving for inference.")
# NOTE: args for profiler
self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)")
self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)")
self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)")
self.parser = parser self.parser = parser
if json_file != "": if json_file != "":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册