未验证 提交 2ed85a39 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #81 from PaddlePaddle/improve_squad

Don't use persistable for fetched variables
......@@ -144,9 +144,6 @@ def create_model(pyreader_name, bert_config, is_training=False):
batch_ones = fluid.layers.fill_constant_batch_size_like(
input=start_logits, dtype='int64', shape=[1], value=1)
num_seqs = fluid.layers.reduce_sum(input=batch_ones)
num_seqs.persistable = True
start_logits.persistable = True
end_logits.persistable = True
if is_training:
......@@ -161,7 +158,6 @@ def create_model(pyreader_name, bert_config, is_training=False):
total_loss = (start_loss + end_loss) / 2.0
if args.use_fp16 and args.loss_scaling > 1.0:
total_loss = total_loss * args.loss_scaling
total_loss.persistable = True
return pyreader, total_loss, num_seqs
else:
......@@ -282,7 +278,7 @@ def train(args):
use_fp16=args.use_fp16,
loss_scaling=args.loss_scaling)
fluid.memory_optimize(train_program)
fluid.memory_optimize(train_program, skip_opt_set=[loss.name, num_seqs.name])
if args.verbose:
if args.in_tokens:
......@@ -304,7 +300,8 @@ def train(args):
bert_config=bert_config,
is_training=False)
fluid.memory_optimize(test_prog)
fluid.memory_optimize(test_prog, skip_opt_set=[unique_ids.name,
start_logits.name, end_logits.name, num_seqs.name])
test_prog = test_prog.clone(for_test=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册