From 73992ce3a66dbf51ecad315cbb0a9145bcea9f2b Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 5 Dec 2018 09:00:54 +0000 Subject: [PATCH] add train_id and trainer_num to select reader --- fluid/PaddleRec/word2vec/infer.py | 6 ++- fluid/PaddleRec/word2vec/reader.py | 72 +++++++++++++++++++----------- fluid/PaddleRec/word2vec/train.py | 14 ++++-- 3 files changed, 61 insertions(+), 31 deletions(-) diff --git a/fluid/PaddleRec/word2vec/infer.py b/fluid/PaddleRec/word2vec/infer.py index 04985fd9..76e04b15 100644 --- a/fluid/PaddleRec/word2vec/infer.py +++ b/fluid/PaddleRec/word2vec/infer.py @@ -139,10 +139,14 @@ def infer_during_train(args): current_list = os.listdir(args.model_output_dir) # logger.info("current_list is : {}".format(current_list)) # logger.info("model_file_list is : {}".format(model_file_list)) + solved_new = True if set(model_file_list) == set(current_list): - logger.info("No New models created") + if solved_new: + solved_new = False + logger.info("No New models created") pass else: + solved_new = True increment_models = list() for f in current_list: if f not in model_file_list: diff --git a/fluid/PaddleRec/word2vec/reader.py b/fluid/PaddleRec/word2vec/reader.py index 4c0259f1..1e6f15ea 100644 --- a/fluid/PaddleRec/word2vec/reader.py +++ b/fluid/PaddleRec/word2vec/reader.py @@ -11,7 +11,13 @@ logger.setLevel(logging.INFO) class Word2VecReader(object): - def __init__(self, dict_path, data_path, filelist, window_size=5): + def __init__(self, + dict_path, + data_path, + filelist, + trainer_id, + trainer_num, + window_size=5): self.window_size_ = window_size self.data_path_ = data_path self.filelist = filelist @@ -20,6 +26,8 @@ class Word2VecReader(object): self.id_to_word = dict() self.word_to_path = dict() self.word_to_code = dict() + self.trainer_id = trainer_id + self.trainer_num = trainer_num word_all_count = 0 word_counts = [] @@ -81,40 +89,50 @@ class Word2VecReader(object): with open(self.data_path_ + "/" + file, 'r') as f: logger.info("running 
data in {}".format(self.data_path_ + "/" + file)) + count = 1 for line in f: - line = preprocess.text_strip(line) - word_ids = [ - self.word_to_id_[word] for word in line.split() - if word in self.word_to_id_ - ] - for idx, target_id in enumerate(word_ids): - context_word_ids = self.get_context_words( - word_ids, idx, self.window_size_) - for context_id in context_word_ids: - yield [target_id], [context_id] + if self.trainer_id == count % self.trainer_num: + line = preprocess.text_strip(line) + word_ids = [ + self.word_to_id_[word] for word in line.split() + if word in self.word_to_id_ + ] + for idx, target_id in enumerate(word_ids): + context_word_ids = self.get_context_words( + word_ids, idx, self.window_size_) + for context_id in context_word_ids: + yield [target_id], [context_id] + else: + pass + count += 1 def _reader_hs(): for file in self.filelist: with open(self.data_path_ + "/" + file, 'r') as f: logger.info("running data in {}".format(self.data_path_ + "/" + file)) + count = 1 for line in f: - line = preprocess.text_strip(line) - word_ids = [ - self.word_to_id_[word] for word in line.split() - if word in self.word_to_id_ - ] - for idx, target_id in enumerate(word_ids): - context_word_ids = self.get_context_words( - word_ids, idx, self.window_size_) - for context_id in context_word_ids: - yield [target_id], [context_id], [ - self.word_to_code[self.id_to_word[ - context_id]] - ], [ - self.word_to_path[self.id_to_word[ - context_id]] - ] + if self.trainer_id == count % self.trainer_num: + line = preprocess.text_strip(line) + word_ids = [ + self.word_to_id_[word] for word in line.split() + if word in self.word_to_id_ + ] + for idx, target_id in enumerate(word_ids): + context_word_ids = self.get_context_words( + word_ids, idx, self.window_size_) + for context_id in context_word_ids: + yield [target_id], [context_id], [ + self.word_to_code[self.id_to_word[ + context_id]] + ], [ + self.word_to_path[self.id_to_word[ + context_id]] + ] + else: + pass + count += 1 
if not with_hs: return _reader diff --git a/fluid/PaddleRec/word2vec/train.py b/fluid/PaddleRec/word2vec/train.py index 59c75c3c..85fa0efd 100644 --- a/fluid/PaddleRec/word2vec/train.py +++ b/fluid/PaddleRec/word2vec/train.py @@ -203,7 +203,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): batch_id) inference_test(global_scope(), model_dir, args) - if batch_id % 1000000 == 0 and batch_id != 0: + if batch_id % 500000 == 0 and batch_id != 0: model_dir = args.model_output_dir + '/batch-' + str( batch_id) fluid.io.save_persistables(executor=exe, dirname=model_dir) @@ -234,8 +234,16 @@ def train(args): os.mkdir(args.model_output_dir) filelist = GetFileList(args.train_data_path) - word2vec_reader = reader.Word2VecReader(args.dict_path, - args.train_data_path, filelist) + word2vec_reader = None + if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1": + word2vec_reader = reader.Word2VecReader( + args.dict_path, args.train_data_path, filelist, 0, 1) + else: + trainer_id = int(os.environ["PADDLE_TRAINER_ID"]) + trainers = int(os.environ["PADDLE_TRAINERS"]) + word2vec_reader = reader.Word2VecReader(args.dict_path, + args.train_data_path, filelist, + trainer_id, trainers) logger.info("dict_size: {}".format(word2vec_reader.dict_size)) loss, py_reader = skip_gram_word2vec( -- GitLab