提交 73992ce3 编写于 作者: J JiabinYang

add trainer_id and trainer_num to shard the training data reader across trainers

上级 d02b745b
......@@ -139,10 +139,14 @@ def infer_during_train(args):
current_list = os.listdir(args.model_output_dir)
# logger.info("current_list is : {}".format(current_list))
# logger.info("model_file_list is : {}".format(model_file_list))
solved_new = True
if set(model_file_list) == set(current_list):
logger.info("No New models created")
if solved_new:
solved_new = False
logger.info("No New models created")
pass
else:
solved_new = True
increment_models = list()
for f in current_list:
if f not in model_file_list:
......
......@@ -11,7 +11,13 @@ logger.setLevel(logging.INFO)
class Word2VecReader(object):
def __init__(self, dict_path, data_path, filelist, window_size=5):
def __init__(self,
dict_path,
data_path,
filelist,
trainer_id,
trainer_num,
window_size=5):
self.window_size_ = window_size
self.data_path_ = data_path
self.filelist = filelist
......@@ -20,6 +26,8 @@ class Word2VecReader(object):
self.id_to_word = dict()
self.word_to_path = dict()
self.word_to_code = dict()
self.trainer_id = trainer_id
self.trainer_num = trainer_num
word_all_count = 0
word_counts = []
......@@ -81,40 +89,50 @@ class Word2VecReader(object):
with open(self.data_path_ + "/" + file, 'r') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
for line in f:
line = preprocess.text_strip(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx, self.window_size_)
for context_id in context_word_ids:
yield [target_id], [context_id]
if self.trainer_id == count % self.trainer_num:
line = preprocess.text_strip(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx, self.window_size_)
for context_id in context_word_ids:
yield [target_id], [context_id]
else:
pass
count += 1
def _reader_hs():
    """Yield hierarchical-softmax training samples from the file list.

    Lines are sharded across distributed trainers: trainer `trainer_id`
    only processes lines whose 1-based index satisfies
    ``index % trainer_num == trainer_id``, so each trainer sees a
    disjoint subset of the corpus.

    Yields:
        ([target_id], [context_id], [code], [path]) — one sample per
        (target, context) pair, where code/path are the Huffman-tree
        code and path of the context word (looked up via id_to_word).
    """
    for file in self.filelist:
        with open(self.data_path_ + "/" + file, 'r') as f:
            logger.info("running data in {}".format(self.data_path_ +
                                                    "/" + file))
            count = 1  # 1-based line counter used for sharding
            for line in f:
                # Only consume lines assigned to this trainer's shard.
                if self.trainer_id == count % self.trainer_num:
                    line = preprocess.text_strip(line)
                    # Drop out-of-vocabulary tokens while mapping to ids.
                    word_ids = [
                        self.word_to_id_[word] for word in line.split()
                        if word in self.word_to_id_
                    ]
                    for idx, target_id in enumerate(word_ids):
                        context_word_ids = self.get_context_words(
                            word_ids, idx, self.window_size_)
                        for context_id in context_word_ids:
                            yield [target_id], [context_id], [
                                self.word_to_code[self.id_to_word[
                                    context_id]]
                            ], [
                                self.word_to_path[self.id_to_word[
                                    context_id]]
                            ]
                count += 1
if not with_hs:
return _reader
......
......@@ -203,7 +203,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
batch_id)
inference_test(global_scope(), model_dir, args)
# Checkpoint persistable variables every 500k batches (skip batch 0).
# NOTE(review): interval was lowered from 1,000,000 to 500,000 in this change.
if batch_id % 500000 == 0 and batch_id != 0:
    model_dir = args.model_output_dir + '/batch-' + str(
        batch_id)
    fluid.io.save_persistables(executor=exe, dirname=model_dir)
......@@ -234,8 +234,16 @@ def train(args):
os.mkdir(args.model_output_dir)
filelist = GetFileList(args.train_data_path)
word2vec_reader = None
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
    # Local (single-process) training: one trainer consumes every line.
    word2vec_reader = reader.Word2VecReader(
        args.dict_path, args.train_data_path, filelist, 0, 1)
else:
    # Distributed training: shard the reader by this trainer's id.
    # Presumably PADDLE_TRAINER_ID / PADDLE_TRAINERS are set by the
    # Paddle launcher — TODO confirm against the launch script.
    trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
    # BUGFIX: the env value was bound as `trainers` while the call below
    # passed the undefined name `trainer_num`, raising NameError here.
    trainer_num = int(os.environ["PADDLE_TRAINERS"])
    word2vec_reader = reader.Word2VecReader(args.dict_path,
                                            args.train_data_path, filelist,
                                            trainer_id, trainer_num)
logger.info("dict_size: {}".format(word2vec_reader.dict_size))
loss, py_reader = skip_gram_word2vec(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册