Unverified commit b9e19be4, authored by Yibing Liu, committed by GitHub

Replace pyreader by data loader (#3651)

Parent: 48f01f52
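
The same migration pattern repeats in every file touched below: `fluid.io.PyReader(feed_list=..., capacity=..., iterable=False)` becomes `fluid.io.DataLoader.from_generator(...)` with the same arguments, `decorate_batch_generator(...)` becomes `set_batch_generator(...)`, and the `start()` / `reset()` / `fluid.core.EOFException` loop is unchanged. A minimal sketch of that pattern, not part of the commit itself; `inputs`, `batch_generator`, `place`, `prog`, and `fetch_list` stand in for the concrete objects each script builds:

```python
import paddle.fluid as fluid

def build_loader(inputs):
    # Old API, removed by this commit:
    #   loader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
    # New API, added by this commit:
    loader = fluid.io.DataLoader.from_generator(
        feed_list=inputs, capacity=50, iterable=False)
    return loader

def run_epoch(exe, prog, loader, batch_generator, place, fetch_list):
    # decorate_batch_generator(...) is renamed to set_batch_generator(...);
    # in non-iterable mode the feeding loop itself does not change.
    loader.set_batch_generator(batch_generator, place)
    loader.start()
    while True:
        try:
            exe.run(prog, fetch_list=fetch_list)
        except fluid.core.EOFException:  # generator exhausted
            loader.reset()
            break
```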
@@ -41,7 +41,8 @@ def create_model(args, bert_config, num_labels, is_prediction=False):
     ]
     (src_ids, pos_ids, sent_ids, input_mask, labels) = inputs

-    pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
+    data_loader = fluid.io.DataLoader.from_generator(
+        feed_list=inputs, capacity=50, iterable=False)

     bert = BertModel(
         src_ids=src_ids,
@@ -71,7 +72,7 @@ def create_model(args, bert_config, num_labels, is_prediction=False):
         feed_targets_name = [
             src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
         ]
-        return pyreader, probs, feed_targets_name
+        return data_loader, probs, feed_targets_name

     logits = fluid.layers.reshape(logits, [-1, num_labels], inplace=True)
     ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
@@ -81,4 +82,4 @@ def create_model(args, bert_config, num_labels, is_prediction=False):
     num_seqs = fluid.layers.create_tensor(dtype='int64')
     accuracy = fluid.layers.accuracy(input=probs, label=labels, total=num_seqs)
-    return pyreader, loss, probs, accuracy, num_seqs
+    return data_loader, loss, probs, accuracy, num_seqs
@@ -17,6 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import six
+import sys
+if six.PY2:
+    reload(sys)
+    sys.setdefaultencoding('utf8')
+
 import os
 import time
 import argparse
@@ -82,7 +88,7 @@ def main(args):
     predict_startup = fluid.Program()
     with fluid.program_guard(predict_prog, predict_startup):
         with fluid.unique_name.guard():
-            predict_pyreader, probs, feed_target_names = create_model(
+            predict_data_loader, probs, feed_target_names = create_model(
                 args,
                 bert_config=bert_config,
                 num_labels=num_labels,
@@ -112,11 +118,11 @@ def main(args):
     predict_exe = fluid.ParallelExecutor(
         use_cuda=args.use_cuda, main_program=predict_prog)

-    predict_pyreader.decorate_batch_generator(
+    predict_data_loader.set_batch_generator(
         processor.data_generator(
             batch_size=args.batch_size, phase='test', epoch=1, shuffle=False))

-    predict_pyreader.start()
+    predict_data_loader.start()
     all_results = []
     time_begin = time.time()
     while True:
@@ -124,7 +130,7 @@ def main(args):
            results = predict_exe.run(fetch_list=[probs.name])
            all_results.extend(results[0])
        except fluid.core.EOFException:
-            predict_pyreader.reset()
+            predict_data_loader.reset()
            break
    time_end = time.time()
...
@@ -17,9 +17,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import six
 import sys
-reload(sys)
-sys.setdefaultencoding('utf8')
+if six.PY2:
+    reload(sys)
+    sys.setdefaultencoding('utf8')

 import os
 import time
@@ -107,8 +109,8 @@ args = parser.parse_args()
 # yapf: enable.

-def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
-    test_pyreader.start()
+def evaluate(exe, test_program, test_data_loader, fetch_list, eval_phase):
+    test_data_loader.start()
     total_cost, total_acc, total_num_seqs = [], [], []
     time_begin = time.time()
     while True:
@@ -119,7 +121,7 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
             total_acc.extend(np_acc * np_num_seqs)
             total_num_seqs.extend(np_num_seqs)
         except fluid.core.EOFException:
-            test_pyreader.reset()
+            test_data_loader.reset()
             break
     time_end = time.time()
     print("[%s evaluation] ave loss: %f, ave acc: %f, elapsed time: %f s" %
@@ -203,7 +205,7 @@ def main(args):
     with fluid.program_guard(train_program, startup_prog):
         with fluid.unique_name.guard():
-            train_pyreader, loss, probs, accuracy, num_seqs = create_model(
+            train_data_loader, loss, probs, accuracy, num_seqs = create_model(
                 args,
                 bert_config=bert_config,
                 num_labels=num_labels)
@@ -228,13 +230,13 @@ def main(args):
         dev_prog = fluid.Program()
         with fluid.program_guard(dev_prog, startup_prog):
             with fluid.unique_name.guard():
-                dev_pyreader, loss, probs, accuracy, num_seqs = create_model(
+                dev_data_loader, loss, probs, accuracy, num_seqs = create_model(
                     args,
                     bert_config=bert_config,
                     num_labels=num_labels)

         dev_prog = dev_prog.clone(for_test=True)
-        dev_pyreader.decorate_batch_generator(
+        dev_data_loader.set_batch_generator(
             processor.data_generator(
                 batch_size=args.batch_size,
                 phase='dev',
@@ -246,13 +248,13 @@ def main(args):
         test_prog = fluid.Program()
         with fluid.program_guard(test_prog, startup_prog):
             with fluid.unique_name.guard():
-                test_pyreader, loss, probs, accuracy, num_seqs = create_model(
+                test_data_loader, loss, probs, accuracy, num_seqs = create_model(
                     args,
                     bert_config=bert_config,
                     num_labels=num_labels)

         test_prog = test_prog.clone(for_test=True)
-        test_pyreader.decorate_batch_generator(
+        test_data_loader.set_batch_generator(
             processor.data_generator(
                 batch_size=args.batch_size,
                 phase='test',
@@ -305,11 +307,11 @@ def main(args):
         train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
                  loss_name=loss.name, build_strategy=build_strategy)

-        train_pyreader.decorate_batch_generator(train_data_generator, place)
+        train_data_loader.set_batch_generator(train_data_generator, place)

     if args.do_train:
-        train_pyreader.start()
+        train_data_loader.start()
         steps = 0
         total_cost, total_acc, total_num_seqs = [], [], []
         time_begin = time.time()
@@ -339,7 +341,7 @@ def main(args):
                     total_num_seqs.extend(np_num_seqs)

                     if args.verbose:
-                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
+                        verbose = "train data_loader queue size: %d, " % train_data_loader.queue.size(
                         )
                         verbose += "learning rate: %f" % np_lr[0]
                         if args.use_fp16:
@@ -375,18 +377,18 @@ def main(args):
                        throughput = []

                    # evaluate dev set
                    if args.do_val:
-                        evaluate(exe, dev_prog, dev_pyreader,
+                        evaluate(exe, dev_prog, dev_data_loader,
                                 [loss.name, accuracy.name, num_seqs.name],
                                 "dev")
                    # evaluate test set
                    if args.do_test:
-                        evaluate(exe, test_prog, test_pyreader,
+                        evaluate(exe, test_prog, test_data_loader,
                                 [loss.name, accuracy.name, num_seqs.name],
                                 "test")
            except fluid.core.EOFException:
                save_path = os.path.join(args.checkpoints, "step_" + str(steps))
                fluid.io.save_persistables(exe, save_path, train_program)
-                train_pyreader.reset()
+                train_data_loader.reset()
                break
    if args.enable_ce:
        card_num = get_cards()
@@ -410,13 +412,13 @@ def main(args):
     # final eval on dev set
     if args.do_val:
         print("Final validation result:")
-        evaluate(exe, dev_prog, dev_pyreader,
+        evaluate(exe, dev_prog, dev_data_loader,
                  [loss.name, accuracy.name, num_seqs.name], "dev")

     # final eval on test set
     if args.do_test:
         print("Final test result:")
-        evaluate(exe, test_prog, test_pyreader,
+        evaluate(exe, test_prog, test_data_loader,
                  [loss.name, accuracy.name, num_seqs.name], "test")
...
@@ -17,9 +17,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import six
 import sys
-reload(sys)
-sys.setdefaultencoding('utf8')
+if six.PY2:
+    reload(sys)
+    sys.setdefaultencoding('utf8')

 import argparse
 import collections
@@ -129,7 +131,7 @@ def create_model(bert_config, is_training=False):
             dtype=input_fields['dtypes'][i],
             lod_level=input_fields['lod_levels'][i]) for i in range(len(input_fields['names']))]

-    pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
+    data_loader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=50, iterable=False)

     if is_training:
         (src_ids, pos_ids, sent_ids, input_mask, start_positions, end_positions) = inputs
@@ -174,23 +176,23 @@ def create_model(bert_config, is_training=False):
         start_loss = compute_loss(start_logits, start_positions)
         end_loss = compute_loss(end_logits, end_positions)
         total_loss = (start_loss + end_loss) / 2.0
-        return pyreader, total_loss, num_seqs
+        return data_loader, total_loss, num_seqs
     else:
-        return pyreader, unique_id, start_logits, end_logits, num_seqs
+        return data_loader, unique_id, start_logits, end_logits, num_seqs


 RawResult = collections.namedtuple("RawResult",
                                    ["unique_id", "start_logits", "end_logits"])


-def predict(test_exe, test_program, test_pyreader, fetch_list, processor):
+def predict(test_exe, test_program, test_data_loader, fetch_list, processor):
     if not os.path.exists(args.checkpoints):
         os.makedirs(args.checkpoints)
     output_prediction_file = os.path.join(args.checkpoints, "predictions.json")
     output_nbest_file = os.path.join(args.checkpoints, "nbest_predictions.json")
     output_null_log_odds_file = os.path.join(args.checkpoints, "null_odds.json")

-    test_pyreader.start()
+    test_data_loader.start()
     all_results = []
     time_begin = time.time()
     while True:
@@ -209,7 +211,7 @@ def predict(test_exe, test_program, test_pyreader, fetch_list, processor):
                         start_logits=start_logits,
                         end_logits=end_logits))
         except fluid.core.EOFException:
-            test_pyreader.reset()
+            test_data_loader.reset()
             break
     time_end = time.time()
@@ -277,7 +279,7 @@ def train(args):
     train_program = fluid.Program()
     with fluid.program_guard(train_program, startup_prog):
         with fluid.unique_name.guard():
-            train_pyreader, loss, num_seqs = create_model(
+            train_data_loader, loss, num_seqs = create_model(
                 bert_config=bert_config,
                 is_training=True)
@@ -302,7 +304,7 @@ def train(args):
         test_prog = fluid.Program()
         with fluid.program_guard(test_prog, startup_prog):
             with fluid.unique_name.guard():
-                test_pyreader, unique_ids, start_logits, end_logits, num_seqs = create_model(
+                test_data_loader, unique_ids, start_logits, end_logits, num_seqs = create_model(
                     bert_config=bert_config,
                     is_training=False)
@@ -346,9 +348,9 @@ def train(args):
         train_compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
                  loss_name=loss.name, exec_strategy=exec_strategy)

-        train_pyreader.decorate_batch_generator(train_data_generator, place)
+        train_data_loader.set_batch_generator(train_data_generator, place)

-        train_pyreader.start()
+        train_data_loader.start()
         steps = 0
         total_cost, total_num_seqs = [], []
         time_begin = time.time()
@@ -374,7 +376,7 @@ def train(args):
                     total_num_seqs.extend(np_num_seqs)

                     if args.verbose:
-                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
+                        verbose = "train data_loader queue size: %d, " % train_data_loader.queue.size(
                         )
                         verbose += "learning rate: %f " % np_lr[0]
                         if args.use_fp16:
@@ -401,11 +403,11 @@ def train(args):
                save_path = os.path.join(args.checkpoints,
                                         "step_" + str(steps) + "_final")
                fluid.io.save_persistables(exe, save_path, train_program)
-                train_pyreader.reset()
+                train_data_loader.reset()
                break

    if args.do_predict:
-        test_pyreader.decorate_batch_generator(
+        test_data_loader.set_batch_generator(
            processor.data_generator(
                data_path=args.predict_file,
                batch_size=args.batch_size,
@@ -414,7 +416,7 @@ def train(args):
                dev_count=1,
                epoch=1), place)

-        predict(exe, test_prog, test_pyreader, [
+        predict(exe, test_prog, test_data_loader, [
            unique_ids.name, start_logits.name, end_logits.name, num_seqs.name
        ], processor)
...
@@ -17,13 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import six
 import sys
-reload(sys)
-sys.setdefaultencoding('utf8')
+if six.PY2:
+    reload(sys)
+    sys.setdefaultencoding('utf8')

 import os
 import time
-import sys
 import argparse
 import numpy as np
 import multiprocessing
@@ -111,7 +112,7 @@ def create_model(bert_config):
     (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels) = inputs

-    pyreader = fluid.io.PyReader(feed_list=inputs, capacity=50, iterable=False)
+    data_loader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=50, iterable=False)

     bert = BertModel(
         src_ids=src_ids,
@@ -125,14 +126,14 @@ def create_model(bert_config):
     next_sent_acc, mask_lm_loss, total_loss = bert.get_pretraining_output(
         mask_label, mask_pos, labels)

-    return pyreader, next_sent_acc, mask_lm_loss, total_loss
+    return data_loader, next_sent_acc, mask_lm_loss, total_loss


 def predict_wrapper(args,
                     exe,
                     bert_config,
                     test_prog=None,
-                    pyreader=None,
+                    data_loader=None,
                     fetch_list=None):
     # Context to do validation.
     data_path = args.test_set_dir if args.do_test else args.validation_set_dir
@@ -147,7 +148,7 @@ def predict_wrapper(args,
         max_seq_len=args.max_seq_len,
         is_test=True)

-    pyreader.decorate_batch_generator(data_reader.data_generator())
+    data_loader.set_batch_generator(data_reader.data_generator())

     if args.do_test:
         assert args.init_checkpoint is not None, "[FATAL] Please use --init_checkpoint '/path/to/checkpoints' \
@@ -155,8 +156,8 @@ def predict_wrapper(args,
         init_pretraining_params(exe, args.init_checkpoint, test_prog)

-    def predict(exe=exe, pyreader=pyreader):
-        pyreader.start()
+    def predict(exe=exe, data_loader=data_loader):
+        data_loader.start()

         cost = 0
         lm_cost = 0
@@ -175,7 +176,7 @@ def predict_wrapper(args,
                     print("[test_set] steps: %d" % steps)

             except fluid.core.EOFException:
-                pyreader.reset()
+                data_loader.reset()
                 break

         used_time = time.time() - time_begin
@@ -192,7 +193,7 @@ def test(args):
     test_startup = fluid.Program()
     with fluid.program_guard(test_prog, test_startup):
         with fluid.unique_name.guard():
-            test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
+            test_data_loader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                 bert_config=bert_config)

     test_prog = test_prog.clone(for_test=True)
@@ -206,7 +207,7 @@ def test(args):
         exe,
         bert_config,
         test_prog=test_prog,
-        pyreader=test_pyreader,
+        data_loader=test_data_loader,
         fetch_list=[next_sent_acc.name, mask_lm_loss.name, total_loss.name])

     print("test begin")
@@ -227,7 +228,7 @@ def train(args):
     startup_prog = fluid.Program()
     with fluid.program_guard(train_program, startup_prog):
         with fluid.unique_name.guard():
-            train_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
+            train_data_loader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                 bert_config=bert_config)
             scheduled_lr, loss_scaling = optimization(
                 loss=total_loss,
@@ -249,7 +250,7 @@ def train(args):
     test_prog = fluid.Program()
     with fluid.program_guard(test_prog, startup_prog):
         with fluid.unique_name.guard():
-            test_pyreader, next_sent_acc, mask_lm_loss, total_loss = create_model(
+            test_data_loader, next_sent_acc, mask_lm_loss, total_loss = create_model(
                 bert_config=bert_config)

     test_prog = test_prog.clone(for_test=True)
@@ -334,13 +335,13 @@ def train(args):
             exe,
             bert_config,
             test_prog=test_prog,
-            pyreader=test_pyreader,
+            data_loader=test_data_loader,
             fetch_list=[
                 next_sent_acc.name, mask_lm_loss.name, total_loss.name
             ])

-    train_pyreader.decorate_batch_generator(data_reader.data_generator())
-    train_pyreader.start()
+    train_data_loader.set_batch_generator(data_reader.data_generator())
+    train_data_loader.start()
     steps = 0
     cost = []
     lm_cost = []
@@ -391,7 +392,7 @@ def train(args):
                 epoch, current_file_index, total_file, current_file = data_reader.get_progress(
                 )
                 if args.verbose:
-                    verbose = "feed_queue size: %d, " %train_pyreader.queue.size()
+                    verbose = "feed_queue size: %d, " %train_data_loader.queue.size()
                     verbose += "current learning_rate: %f, " % np_lr[0]
                     if args.use_fp16:
                         verbose += "loss scaling: %f" % np_scaling[0]
@@ -426,7 +427,7 @@ def train(args):
                        np.mean(np.array(vali_acc) / vali_steps), vali_speed))

         except fluid.core.EOFException:
-            train_pyreader.reset()
+            train_data_loader.reset()
             break


 if __name__ == '__main__':
...
@@ -119,9 +119,10 @@ def create_master_params_grads(params_grads, main_prog, startup_prog,
 def master_param_to_train_param(master_params_grads, params_grads, main_prog):
     for idx, m_p_g in enumerate(master_params_grads):
         train_p, _ = params_grads[idx]
-        if train_p.name.find("layer_norm") > -1:
-            continue
         with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
-            append_cast_op(m_p_g[0], train_p, main_prog)
+            if train_p.name.find("layer_norm") > -1:
+                fluid.layers.assign(m_p_g[0], train_p)
+            else:
+                append_cast_op(m_p_g[0], train_p, main_prog)
...
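
The final hunk above also changes fp16 behavior, not just naming: layer_norm parameters, which were previously skipped when copying the fp32 master weights back, are now updated in place with `fluid.layers.assign`, while all other parameters are still cast back to fp16. A rough standalone restatement of the updated helper for readability, not part of the commit; it assumes `append_cast_op` is the helper defined in the same utils module that appends an fp32-to-fp16 cast op to the program:

```python
import paddle.fluid as fluid
# Assumption: append_cast_op(src, dst, prog) comes from the same utils module
# and appends an fp32 -> fp16 cast op to prog.

def master_param_to_train_param(master_params_grads, params_grads, main_prog):
    # master_params_grads: fp32 "master" copies; params_grads: the original
    # training parameters (mostly fp16), in the same order.
    for idx, m_p_g in enumerate(master_params_grads):
        train_p, _ = params_grads[idx]
        with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
            if train_p.name.find("layer_norm") > -1:
                # Previously skipped entirely; now the fp32 master value is
                # copied back directly, keeping layer_norm params in fp32.
                fluid.layers.assign(m_p_g[0], train_p)
            else:
                # All other parameters are cast back from fp32 to fp16.
                append_cast_op(m_p_g[0], train_p, main_prog)
```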