提交 febbb0f8 编写于 作者: X xuezhong

add ce

上级 a01821e7
......@@ -120,5 +120,9 @@ def parse_args():
'--result_name',
default='test_result',
help='the file name of the results')
parser.add_argument(
"--enable_ce",
action='store_true',
help="If set, run the task with continuous evaluation logs.")
args = parser.parse_args()
return args
......@@ -21,6 +21,7 @@ if [[ -d preprocessed ]] && [[ -d raw ]]; then
exit 0
else
wget -c --no-check-certificate http://dureader.gz.bcebos.com/dureader_preprocessed.zip
wget -c --no-check-certificate http://dureader.gz.bcebos.com/demo.tgz
fi
if md5sum --status -c md5sum.txt; then
......
......@@ -152,7 +152,7 @@ class BRCDataset(object):
batch_data['passage_token_ids'].append(passage_token_ids)
batch_data['passage_length'].append(
min(len(passage_token_ids), self.max_p_len))
# record the start passage index of current doc
# record the start passage index of current sample
passade_idx_offset = sum(batch_data['passage_num'])
batch_data['passage_num'].append(count)
gold_passage_offset = 0
......
......@@ -317,4 +317,5 @@ def rc_model(hidden_size, vocab, args):
cost.persistable = True
feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
layers.Print(ms, message='ms', summarize=3)
return cost, start_probs, end_probs, ms, feeding_list
......@@ -236,7 +236,11 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
total_loss += np.array(val_fetch_outs[0]).sum()
start_probs_m = LodTensor_Array(val_fetch_outs[1])
end_probs_m = LodTensor_Array(val_fetch_outs[2])
for data in feed_data:
data_len = [[len(y) for y in x[3]] for x in data]
logger.info(str(data_len))
match_lod = val_fetch_outs[3].lod()
logger.info(str(match_lod))
count += len(np.array(val_fetch_outs[0]))
n_batch_cnt += len(np.array(val_fetch_outs[0]))
......@@ -413,6 +417,8 @@ def train(logger, args):
n_batch_loss += cost_train
total_loss += cost_train * args.batch_size * dev_count
if args.enable_ce and batch_id >= 100:
break
if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
print_para(main_program, parallel_executor, logger,
args)
......@@ -457,6 +463,14 @@ def train(logger, args):
executor=exe,
dirname=model_path,
main_program=main_program)
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" %
(dev_count, total_loss / total_num))
if brc_data.dev_set is not None:
print("kpis\ttest_cost_card%d\t%f" %
(dev_count, eval_loss))
print("kpis\ttrain_duration_card%d\t%f" %
(dev_count, time_consumed))
def evaluate(logger, args):
......
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=2
python run.py \
--trainset 'data/preprocessed/trainset/search.train.json' \
'data/preprocessed/trainset/zhidao.train.json' \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册