提交 ebda2052 编写于 作者: Y Yu Yang

Refine Reader

上级 8c32619e
...@@ -117,7 +117,6 @@ def skip_gram_word2vec(dict_size, ...@@ -117,7 +117,6 @@ def skip_gram_word2vec(dict_size,
cost = cost_hs cost = cost_hs
if with_nce and with_hsigmoid: if with_nce and with_hsigmoid:
cost = fluid.layers.elementwise_add(cost_nce, cost_hs) cost = fluid.layers.elementwise_add(cost_nce, cost_hs)
avg_cost = fluid.layers.reduce_mean(cost) avg_cost = fluid.layers.reduce_mean(cost)
return avg_cost, py_reader return avg_cost, py_reader
...@@ -31,9 +31,9 @@ def parse_args(): ...@@ -31,9 +31,9 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
pattern = re.compile("[^a-z] ")
def text_strip(text): def text_strip(text):
return re.sub("[^a-z ]", "", text) return pattern.sub("", text)
def build_Huffman(word_count, max_code_length): def build_Huffman(word_count, max_code_length):
......
...@@ -10,6 +10,23 @@ logger = logging.getLogger("fluid") ...@@ -10,6 +10,23 @@ logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class NumpyRandomInt(object):
def __init__(self, a, b, buf_size=1000):
self.idx = 0
self.buffer = np.random.random_integers(a, b, buf_size)
self.a = a
self.b = b
def __call__(self):
if self.idx == len(self.buffer):
self.buffer = np.random.random_integers(self.a, self.b, len(self.buffer))
self.idx = 0
result = self.buffer[self.idx]
self.idx += 1
return result
class Word2VecReader(object): class Word2VecReader(object):
def __init__(self, def __init__(self,
dict_path, dict_path,
...@@ -37,7 +54,7 @@ class Word2VecReader(object): ...@@ -37,7 +54,7 @@ class Word2VecReader(object):
for line in f: for line in f:
word, count = line.split()[0], int(line.split()[1]) word, count = line.split()[0], int(line.split()[1])
self.word_to_id_[word] = word_id self.word_to_id_[word] = word_id
self.id_to_word[word_id] = word #build id to word dict self.id_to_word[word_id] = word # build id to word dict
word_id += 1 word_id += 1
word_counts.append(count) word_counts.append(count)
word_all_count += count word_all_count += count
...@@ -67,7 +84,9 @@ class Word2VecReader(object): ...@@ -67,7 +84,9 @@ class Word2VecReader(object):
line.split(':')[1], dtype=int, sep=' ') line.split(':')[1], dtype=int, sep=' ')
print("word_pcode dict_size = " + str(len(self.word_to_code))) print("word_pcode dict_size = " + str(len(self.word_to_code)))
def get_context_words(self, words, idx, window_size): self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)
def get_context_words(self, words, idx):
""" """
Get the context word list of target word. Get the context word list of target word.
...@@ -75,13 +94,15 @@ class Word2VecReader(object): ...@@ -75,13 +94,15 @@ class Word2VecReader(object):
idx: input word index idx: input word index
window_size: window size window_size: window size
""" """
target_window = np.random.randint(1, window_size + 1) target_window = self.random_generator()
# need to keep in mind that maybe there are no enough words before the target word. # need to keep in mind that maybe there are no enough words before the target word.
start_point = idx - target_window if (idx - target_window) > 0 else 0 start_point = idx - target_window # if (idx - target_window) > 0 else 0
if start_point < 0:
start_point = 0
end_point = idx + target_window end_point = idx + target_window
# context words of the target word # context words of the target word
targets = set(words[start_point:idx] + words[idx + 1:end_point + 1]) targets = words[start_point:idx] + words[idx + 1:end_point + 1]
return list(targets) return set(targets)
def train(self, with_hs): def train(self, with_hs):
def _reader(): def _reader():
...@@ -98,10 +119,10 @@ class Word2VecReader(object): ...@@ -98,10 +119,10 @@ class Word2VecReader(object):
if word in self.word_to_id_ if word in self.word_to_id_
] ]
for idx, target_id in enumerate(word_ids): for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words( context_word_ids = self.get_context_words(word_ids, idx)
word_ids, idx, self.window_size_)
for context_id in context_word_ids: for context_id in context_word_ids:
yield [target_id], [context_id] yield [target_id], [context_id]
else: else:
pass pass
count += 1 count += 1
...@@ -120,16 +141,15 @@ class Word2VecReader(object): ...@@ -120,16 +141,15 @@ class Word2VecReader(object):
if word in self.word_to_id_ if word in self.word_to_id_
] ]
for idx, target_id in enumerate(word_ids): for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words( context_word_ids = self.get_context_words(word_ids, idx)
word_ids, idx, self.window_size_)
for context_id in context_word_ids: for context_id in context_word_ids:
yield [target_id], [context_id], [ yield [target_id], [context_id], [
self.word_to_code[self.id_to_word[ self.word_to_code[self.id_to_word[
context_id]] context_id]]
], [ ], [
self.word_to_path[self.id_to_word[ self.word_to_path[self.id_to_word[
context_id]] context_id]]
] ]
else: else:
pass pass
count += 1 count += 1
...@@ -142,13 +162,10 @@ class Word2VecReader(object): ...@@ -142,13 +162,10 @@ class Word2VecReader(object):
if __name__ == "__main__": if __name__ == "__main__":
window_size = 10 window_size = 10
reader = Word2VecReader("data/1-billion_dict",
reader = Word2VecReader("data/enwik9_dict", "data/enwik9", window_size) "data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/",
i = 0 ['news.en-00001-of-00100'],
for x, y in reader.train()(): trainer_id=0, trainer_num=1, window_size=5)
print("x: " + str(x)) # i = 0
print("y: " + str(y)) for x, y in reader.train(False)():
print("\n") pass
if i == 10:
exit(0)
i += 1
...@@ -4,11 +4,10 @@ import argparse ...@@ -4,11 +4,10 @@ import argparse
import logging import logging
import os import os
import time import time
import numpy as np import numpy as np
import six
# disable gpu training for this example # disable gpu training for this example
os.environ["CUDA_VISIBLE_DEVICES"] = "" # os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -49,7 +48,7 @@ def parse_args(): ...@@ -49,7 +48,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--num_passes', '--num_passes',
type=int, type=int,
default=10, default=1,
help="The number of passes to train (default: 10)") help="The number of passes to train (default: 10)")
parser.add_argument( parser.add_argument(
'--model_output_dir', '--model_output_dir',
...@@ -126,14 +125,35 @@ def parse_args(): ...@@ -126,14 +125,35 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def train_loop(args, train_program, reader, py_reader, loss, trainer_id): def convert_python_to_tensor(batch_size, sample_reader):
train_reader = paddle.batch( def __reader__():
paddle.reader.shuffle( result = [[], [], [], []]
reader.train((args.with_hs or (not args.with_nce))), for sample in sample_reader():
buf_size=args.batch_size * 100), for i, fea in enumerate(sample):
batch_size=args.batch_size) result[i].append(fea)
if len(result[0]) == batch_size:
tensor_result = []
for tensor in result:
t = fluid.Tensor()
dat = np.array(tensor, dtype='int64')
if len(dat.shape) > 2:
dat = dat.reshape((dat.shape[0], dat.shape[2]))
elif len(dat.shape) == 1:
dat = dat.reshape((-1, 1))
t.set(dat, fluid.CPUPlace())
tensor_result.append(t)
yield tensor_result
result = [[], [], [], []]
py_reader.decorate_paddle_reader(train_reader) return __reader__
def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
py_reader.decorate_tensor_provider(convert_python_to_tensor(
args.batch_size, reader.train((args.with_hs or (not args.with_nce)))))
# py_reader.decorate_paddle_reader(train_reader)
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -144,6 +164,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -144,6 +164,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
print("CPU_NUM:" + str(os.getenv("CPU_NUM"))) print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
exec_strategy.num_threads = int(os.getenv("CPU_NUM")) exec_strategy.num_threads = int(os.getenv("CPU_NUM"))
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
if int(os.getenv("CPU_NUM")) > 1: if int(os.getenv("CPU_NUM")) > 1:
...@@ -156,43 +177,31 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): ...@@ -156,43 +177,31 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
profile_state = "CPU"
profiler_step = 0
profiler_step_start = 20
profiler_step_end = 30
for pass_id in range(args.num_passes): for pass_id in range(args.num_passes):
epoch_start = time.time()
py_reader.start() py_reader.start()
time.sleep(10) # wait reading data.
epoch_start = time.time()
batch_id = 0 batch_id = 0
start = time.clock() start = time.clock()
try: try:
while True: while True:
if profiler_step == profiler_step_start:
fluid.profiler.start_profiler(profile_state)
loss_val = train_exe.run(fetch_list=[loss.name]) loss_val = train_exe.run(fetch_list=[loss.name])
loss_val = np.mean(loss_val) loss_val = np.mean(loss_val)
if profiler_step == profiler_step_end:
fluid.profiler.stop_profiler('total', 'trainer_profile.log')
profiler_step += 1
else:
profiler_step += 1
if batch_id % 50 == 0: if batch_id % 50 == 0:
logger.info( logger.info(
"TRAIN --> pass: {} batch: {} loss: {} reader queue:{}". "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
format(pass_id, batch_id, format(pass_id, batch_id,
loss_val.mean() / args.batch_size, loss_val,
py_reader.queue.size())) py_reader.queue.size()))
if batch_id == 1000:
exit(0)
if args.with_speed: if args.with_speed:
if batch_id % 1000 == 0 and batch_id != 0: if batch_id % 100 == 0 and batch_id != 0:
elapsed = (time.clock() - start) elapsed = (time.clock() - start)
start = time.clock() start = time.clock()
samples = 1001 * args.batch_size * int( samples = 101 * args.batch_size * int(
os.getenv("CPU_NUM")) os.getenv("CPU_NUM"))
logger.info("Time used: {}, Samples/Sec: {}".format( logger.info("Time used: {}, Samples/Sec: {}".format(
elapsed, samples / elapsed)) elapsed, samples / elapsed))
...@@ -229,11 +238,12 @@ def GetFileList(data_path): ...@@ -229,11 +238,12 @@ def GetFileList(data_path):
def train(args): def train(args):
print("I am ehre")
if not os.path.isdir(args.model_output_dir): if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir) os.mkdir(args.model_output_dir)
filelist = GetFileList(args.train_data_path) filelist = GetFileList(args.train_data_path)[:1]
print(filelist)
word2vec_reader = None word2vec_reader = None
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1": if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
word2vec_reader = reader.Word2VecReader( word2vec_reader = reader.Word2VecReader(
...@@ -329,7 +339,7 @@ def env_declar(): ...@@ -329,7 +339,7 @@ def env_declar():
print("%30s %s \n" % (key, os.environ[key])) print("%30s %s \n" % (key, os.environ[key]))
if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[ if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[
"PADDLE_IS_LOCAL"] == "0": "PADDLE_IS_LOCAL"] == "0":
os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"] os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"]
os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"] os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"]
os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"] os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册