From febbb0f8f202083c612515539a402b54581e1d4f Mon Sep 17 00:00:00 2001
From: xuezhong
Date: Wed, 31 Oct 2018 12:34:06 +0000
Subject: [PATCH] add ce

---
 fluid/machine_reading_comprehension/args.py          |  4 ++++
 .../machine_reading_comprehension/data/download.sh   |  1 +
 fluid/machine_reading_comprehension/dataset.py       |  2 +-
 fluid/machine_reading_comprehension/rc_model.py      |  1 +
 fluid/machine_reading_comprehension/run.py           | 14 ++++++++++++++
 fluid/machine_reading_comprehension/run.sh           |  2 +-
 6 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/fluid/machine_reading_comprehension/args.py b/fluid/machine_reading_comprehension/args.py
index e37aad9a..53812252 100644
--- a/fluid/machine_reading_comprehension/args.py
+++ b/fluid/machine_reading_comprehension/args.py
@@ -120,5 +120,9 @@ def parse_args():
         '--result_name',
         default='test_result',
         help='the file name of the results')
+    parser.add_argument(
+        "--enable_ce",
+        action='store_true',
+        help="If set, run the task with continuous evaluation logs.")
     args = parser.parse_args()
     return args
diff --git a/fluid/machine_reading_comprehension/data/download.sh b/fluid/machine_reading_comprehension/data/download.sh
index 41f79dd0..bcba3c7e 100644
--- a/fluid/machine_reading_comprehension/data/download.sh
+++ b/fluid/machine_reading_comprehension/data/download.sh
@@ -21,6 +21,7 @@ if [[ -d preprocessed ]] && [[ -d raw ]]; then
     exit 0
 else
     wget -c --no-check-certificate http://dureader.gz.bcebos.com/dureader_preprocessed.zip
+    wget -c --no-check-certificate http://dureader.gz.bcebos.com/demo.tgz
 fi
 
 if md5sum --status -c md5sum.txt; then
diff --git a/fluid/machine_reading_comprehension/dataset.py b/fluid/machine_reading_comprehension/dataset.py
index 7e50b7e0..3aaf87be 100644
--- a/fluid/machine_reading_comprehension/dataset.py
+++ b/fluid/machine_reading_comprehension/dataset.py
@@ -152,7 +152,7 @@ class BRCDataset(object):
                 batch_data['passage_token_ids'].append(passage_token_ids)
                 batch_data['passage_length'].append(
                     min(len(passage_token_ids), self.max_p_len))
-            # record the start passage index of current doc
+            # record the start passage index of current sample
             passade_idx_offset = sum(batch_data['passage_num'])
             batch_data['passage_num'].append(count)
             gold_passage_offset = 0
diff --git a/fluid/machine_reading_comprehension/rc_model.py b/fluid/machine_reading_comprehension/rc_model.py
index 932ccd9c..d95fa9d7 100644
--- a/fluid/machine_reading_comprehension/rc_model.py
+++ b/fluid/machine_reading_comprehension/rc_model.py
@@ -317,4 +317,5 @@ def rc_model(hidden_size, vocab, args):
     cost.persistable = True
 
     feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
+    layers.Print(ms, message='ms', summarize=3)
     return cost, start_probs, end_probs, ms, feeding_list
diff --git a/fluid/machine_reading_comprehension/run.py b/fluid/machine_reading_comprehension/run.py
index 0ab05b90..5959709b 100644
--- a/fluid/machine_reading_comprehension/run.py
+++ b/fluid/machine_reading_comprehension/run.py
@@ -236,7 +236,11 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
         total_loss += np.array(val_fetch_outs[0]).sum()
         start_probs_m = LodTensor_Array(val_fetch_outs[1])
         end_probs_m = LodTensor_Array(val_fetch_outs[2])
+        for data in feed_data:
+            data_len = [[len(y) for y in x[3]] for x in data]
+            logger.info(str(data_len))
         match_lod = val_fetch_outs[3].lod()
+        logger.info(str(match_lod))
         count += len(np.array(val_fetch_outs[0]))
 
         n_batch_cnt += len(np.array(val_fetch_outs[0]))
@@ -413,6 +417,8 @@ def train(logger, args):
                 n_batch_loss += cost_train
                 total_loss += cost_train * args.batch_size * dev_count
+                if args.enable_ce and batch_id >= 100:
+                    break
 
                 if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
                     print_para(main_program, parallel_executor, logger, args)
@@ -457,6 +463,14 @@
                     executor=exe,
                     dirname=model_path,
                     main_program=main_program)
+        if args.enable_ce:  # For CE
+            print("kpis\ttrain_cost_card%d\t%f" %
+                  (dev_count, total_loss / total_num))
+            if brc_data.dev_set is not None:
+                print("kpis\ttest_cost_card%d\t%f" %
+                      (dev_count, eval_loss))
+            print("kpis\ttrain_duration_card%d\t%f" %
+                  (dev_count, time_consumed))
 
 
 def evaluate(logger, args):
diff --git a/fluid/machine_reading_comprehension/run.sh b/fluid/machine_reading_comprehension/run.sh
index a8241d05..c3e2d8c2 100644
--- a/fluid/machine_reading_comprehension/run.sh
+++ b/fluid/machine_reading_comprehension/run.sh
@@ -1,4 +1,4 @@
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=2
 python run.py \
     --trainset 'data/preprocessed/trainset/search.train.json' \
                'data/preprocessed/trainset/zhidao.train.json' \
--
GitLab