提交 25571110 编写于 作者: H hysunflower 提交者: Jinhua Liang

Add for bert models (#3976)

* add_for_bert_models

* update add_for_bert_models

* update for add_for_bert_models
上级 2e4465c6
......@@ -32,6 +32,7 @@ import multiprocessing
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import reader.cls as reader
from model.bert import BertConfig
......@@ -93,6 +94,12 @@ data_g.add_arg("do_lower_case", bool, True,
data_g.add_arg("random_seed", int, 0, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
# NOTE: profiler arguments below are used only for benchmarking.
run_type_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)")
run_type_g.add_arg("is_profiler", int, 0, "the profiler switch. (used for benchmark)")
run_type_g.add_arg("max_iter", int, 0, "the max batch nums to train. (used for benchmark)")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("shuffle", bool, True, "")
......@@ -317,9 +324,17 @@ def main(args):
time_begin = time.time()
throughput = []
ce_info = []
total_batch_num=0 # used for benchmark
while True:
try:
steps += 1
total_batch_num += 1 # used for benchmark
if args.max_iter and total_batch_num == args.max_iter: # used for benchmark
return
if steps % args.skip_steps == 0:
if args.use_fp16:
fetch_list = [loss.name, accuracy.name, scheduled_lr.name, num_seqs.name, loss_scaling.name]
......@@ -353,6 +368,13 @@ def main(args):
time_end = time.time()
used_time = time_end - time_begin
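# Profiling (benchmark only): in epoch 0, start the profiler when steps == skip_steps;
# when steps == 2 * skip_steps, stop it (output path: args.profiler_path) and return.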
if args.is_profiler and current_epoch == 0 and steps == args.skip_steps:
profiler.start_profiler("All")
elif args.is_profiler and current_epoch == 0 and steps == args.skip_steps * 2:
profiler.stop_profiler("total", args.profiler_path)
return
log_record = "epoch: {}, progress: {}/{}, step: {}, ave loss: {}, ave acc: {}".format(
current_epoch, current_example, num_train_examples,
steps, np.sum(total_cost) / np.sum(total_num_seqs),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册