提交 73992ce3 编写于 作者: J JiabinYang

add trainer_id and trainer_num to select reader

上级 d02b745b
...@@ -139,10 +139,14 @@ def infer_during_train(args): ...@@ -139,10 +139,14 @@ def infer_during_train(args):
current_list = os.listdir(args.model_output_dir) current_list = os.listdir(args.model_output_dir)
# logger.info("current_list is : {}".format(current_list)) # logger.info("current_list is : {}".format(current_list))
# logger.info("model_file_list is : {}".format(model_file_list)) # logger.info("model_file_list is : {}".format(model_file_list))
solved_new = True
if set(model_file_list) == set(current_list): if set(model_file_list) == set(current_list):
logger.info("No New models created") if solved_new:
solved_new = False
logger.info("No New models created")
pass pass
else: else:
solved_new = True
increment_models = list() increment_models = list()
for f in current_list: for f in current_list:
if f not in model_file_list: if f not in model_file_list:
......
...@@ -11,7 +11,13 @@ logger.setLevel(logging.INFO) ...@@ -11,7 +11,13 @@ logger.setLevel(logging.INFO)
class Word2VecReader(object): class Word2VecReader(object):
def __init__(self, dict_path, data_path, filelist, window_size=5): def __init__(self,
dict_path,
data_path,
filelist,
trainer_id,
trainer_num,
window_size=5):
self.window_size_ = window_size self.window_size_ = window_size
self.data_path_ = data_path self.data_path_ = data_path
self.filelist = filelist self.filelist = filelist
...@@ -20,6 +26,8 @@ class Word2VecReader(object): ...@@ -20,6 +26,8 @@ class Word2VecReader(object):
self.id_to_word = dict() self.id_to_word = dict()
self.word_to_path = dict() self.word_to_path = dict()
self.word_to_code = dict() self.word_to_code = dict()
self.trainer_id = trainer_id
self.trainer_num = trainer_num
word_all_count = 0 word_all_count = 0
word_counts = [] word_counts = []
...@@ -81,40 +89,50 @@ class Word2VecReader(object): ...@@ -81,40 +89,50 @@ class Word2VecReader(object):
with open(self.data_path_ + "/" + file, 'r') as f: with open(self.data_path_ + "/" + file, 'r') as f:
logger.info("running data in {}".format(self.data_path_ + logger.info("running data in {}".format(self.data_path_ +
"/" + file)) "/" + file))
count = 1
for line in f: for line in f:
line = preprocess.text_strip(line) if self.trainer_id == count % self.trainer_num:
word_ids = [ line = preprocess.text_strip(line)
self.word_to_id_[word] for word in line.split() word_ids = [
if word in self.word_to_id_ self.word_to_id_[word] for word in line.split()
] if word in self.word_to_id_
for idx, target_id in enumerate(word_ids): ]
context_word_ids = self.get_context_words( for idx, target_id in enumerate(word_ids):
word_ids, idx, self.window_size_) context_word_ids = self.get_context_words(
for context_id in context_word_ids: word_ids, idx, self.window_size_)
yield [target_id], [context_id] for context_id in context_word_ids:
yield [target_id], [context_id]
else:
pass
count += 1
def _reader_hs(): def _reader_hs():
for file in self.filelist: for file in self.filelist:
with open(self.data_path_ + "/" + file, 'r') as f: with open(self.data_path_ + "/" + file, 'r') as f:
logger.info("running data in {}".format(self.data_path_ + logger.info("running data in {}".format(self.data_path_ +
"/" + file)) "/" + file))
count = 1
for line in f: for line in f:
line = preprocess.text_strip(line) if self.trainer_id == count % self.trainer_num:
word_ids = [ line = preprocess.text_strip(line)
self.word_to_id_[word] for word in line.split() word_ids = [
if word in self.word_to_id_ self.word_to_id_[word] for word in line.split()
] if word in self.word_to_id_
for idx, target_id in enumerate(word_ids): ]
context_word_ids = self.get_context_words( for idx, target_id in enumerate(word_ids):
word_ids, idx, self.window_size_) context_word_ids = self.get_context_words(
for context_id in context_word_ids: word_ids, idx, self.window_size_)
yield [target_id], [context_id], [ for context_id in context_word_ids:
self.word_to_code[self.id_to_word[ yield [target_id], [context_id], [
context_id]] self.word_to_code[self.id_to_word[
], [ context_id]]
self.word_to_path[self.id_to_word[ ], [
context_id]] self.word_to_path[self.id_to_word[
] context_id]]
]
else:
pass
count += 1
if not with_hs: if not with_hs:
return _reader return _reader
......
...@@ -203,7 +203,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -203,7 +203,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
batch_id) batch_id)
inference_test(global_scope(), model_dir, args) inference_test(global_scope(), model_dir, args)
if batch_id % 1000000 == 0 and batch_id != 0: if batch_id % 500000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str( model_dir = args.model_output_dir + '/batch-' + str(
batch_id) batch_id)
fluid.io.save_persistables(executor=exe, dirname=model_dir) fluid.io.save_persistables(executor=exe, dirname=model_dir)
...@@ -234,8 +234,16 @@ def train(args): ...@@ -234,8 +234,16 @@ def train(args):
os.mkdir(args.model_output_dir) os.mkdir(args.model_output_dir)
filelist = GetFileList(args.train_data_path) filelist = GetFileList(args.train_data_path)
word2vec_reader = reader.Word2VecReader(args.dict_path, word2vec_reader = None
args.train_data_path, filelist) if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
word2vec_reader = reader.Word2VecReader(
args.dict_path, args.train_data_path, filelist, 0, 1)
else:
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
word2vec_reader = reader.Word2VecReader(args.dict_path,
args.train_data_path, filelist,
                                                trainer_id, trainers)
logger.info("dict_size: {}".format(word2vec_reader.dict_size)) logger.info("dict_size: {}".format(word2vec_reader.dict_size))
loss, py_reader = skip_gram_word2vec( loss, py_reader = skip_gram_word2vec(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册