From ebda2052053368d8a667cdf2a9fabbe0ca9754c4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 18 Dec 2018 14:43:17 +0800 Subject: [PATCH] Refine Reader --- fluid/PaddleRec/word2vec/network_conf.py | 1 - fluid/PaddleRec/word2vec/preprocess.py | 4 +- fluid/PaddleRec/word2vec/reader.py | 63 +++++++++++------- fluid/PaddleRec/word2vec/train.py | 82 +++++++++++++----------- 4 files changed, 88 insertions(+), 62 deletions(-) diff --git a/fluid/PaddleRec/word2vec/network_conf.py b/fluid/PaddleRec/word2vec/network_conf.py index 5b8e9513..2c78063d 100644 --- a/fluid/PaddleRec/word2vec/network_conf.py +++ b/fluid/PaddleRec/word2vec/network_conf.py @@ -117,7 +117,6 @@ def skip_gram_word2vec(dict_size, cost = cost_hs if with_nce and with_hsigmoid: cost = fluid.layers.elementwise_add(cost_nce, cost_hs) - avg_cost = fluid.layers.reduce_mean(cost) return avg_cost, py_reader diff --git a/fluid/PaddleRec/word2vec/preprocess.py b/fluid/PaddleRec/word2vec/preprocess.py index cb8dd100..47e5eef7 100644 --- a/fluid/PaddleRec/word2vec/preprocess.py +++ b/fluid/PaddleRec/word2vec/preprocess.py @@ -31,9 +31,9 @@ def parse_args(): return parser.parse_args() - +pattern = re.compile("[^a-z] ") def text_strip(text): - return re.sub("[^a-z ]", "", text) + return pattern.sub("", text) def build_Huffman(word_count, max_code_length): diff --git a/fluid/PaddleRec/word2vec/reader.py b/fluid/PaddleRec/word2vec/reader.py index 1e6f15ea..f3d2a83b 100644 --- a/fluid/PaddleRec/word2vec/reader.py +++ b/fluid/PaddleRec/word2vec/reader.py @@ -10,6 +10,23 @@ logger = logging.getLogger("fluid") logger.setLevel(logging.INFO) +class NumpyRandomInt(object): + def __init__(self, a, b, buf_size=1000): + self.idx = 0 + self.buffer = np.random.random_integers(a, b, buf_size) + self.a = a + self.b = b + + def __call__(self): + if self.idx == len(self.buffer): + self.buffer = np.random.random_integers(self.a, self.b, len(self.buffer)) + self.idx = 0 + + result = self.buffer[self.idx] + self.idx += 1 + return result + + class Word2VecReader(object): def __init__(self, dict_path, @@ -37,7 +54,7 @@ class Word2VecReader(object): for line in f: word, count = line.split()[0], int(line.split()[1]) self.word_to_id_[word] = word_id - self.id_to_word[word_id] = word #build id to word dict + self.id_to_word[word_id] = word # build id to word dict word_id += 1 word_counts.append(count) word_all_count += count @@ -67,7 +84,9 @@ class Word2VecReader(object): line.split(':')[1], dtype=int, sep=' ') print("word_pcode dict_size = " + str(len(self.word_to_code))) - def get_context_words(self, words, idx, window_size): + self.random_generator = NumpyRandomInt(1, self.window_size_ + 1) + + def get_context_words(self, words, idx): """ Get the context word list of target word. @@ -75,13 +94,15 @@ class Word2VecReader(object): idx: input word index window_size: window size """ - target_window = np.random.randint(1, window_size + 1) + target_window = self.random_generator() # need to keep in mind that maybe there are no enough words before the target word. - start_point = idx - target_window if (idx - target_window) > 0 else 0 + start_point = idx - target_window # if (idx - target_window) > 0 else 0 + if start_point < 0: + start_point = 0 end_point = idx + target_window # context words of the target word - targets = set(words[start_point:idx] + words[idx + 1:end_point + 1]) - return list(targets) + targets = words[start_point:idx] + words[idx + 1:end_point + 1] + return set(targets) def train(self, with_hs): def _reader(): @@ -98,10 +119,10 @@ class Word2VecReader(object): if word in self.word_to_id_ ] for idx, target_id in enumerate(word_ids): - context_word_ids = self.get_context_words( - word_ids, idx, self.window_size_) + context_word_ids = self.get_context_words(word_ids, idx) for context_id in context_word_ids: yield [target_id], [context_id] + else: pass count += 1 @@ -120,16 +141,15 @@ class Word2VecReader(object): if word in self.word_to_id_ ] for idx, target_id in enumerate(word_ids): - context_word_ids = self.get_context_words( - word_ids, idx, self.window_size_) + context_word_ids = self.get_context_words(word_ids, idx) for context_id in context_word_ids: yield [target_id], [context_id], [ self.word_to_code[self.id_to_word[ context_id]] ], [ - self.word_to_path[self.id_to_word[ - context_id]] - ] + self.word_to_path[self.id_to_word[ + context_id]] + ] else: pass count += 1 @@ -142,13 +162,10 @@ class Word2VecReader(object): if __name__ == "__main__": window_size = 10 - - reader = Word2VecReader("data/enwik9_dict", "data/enwik9", window_size) - i = 0 - for x, y in reader.train()(): - print("x: " + str(x)) - print("y: " + str(y)) - print("\n") - if i == 10: - exit(0) - i += 1 + reader = Word2VecReader("data/1-billion_dict", + "data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/", + ['news.en-00001-of-00100'], + trainer_id=0, trainer_num=1, window_size=5) + # i = 0 + for x, y in reader.train(False)(): + pass diff --git a/fluid/PaddleRec/word2vec/train.py b/fluid/PaddleRec/word2vec/train.py index 85fa0efd..53053954 100644 --- a/fluid/PaddleRec/word2vec/train.py +++ b/fluid/PaddleRec/word2vec/train.py @@ -4,11 +4,10 @@ import argparse import logging import os import time - import numpy as np - +import six # disable gpu training for this example -os.environ["CUDA_VISIBLE_DEVICES"] = "" +# os.environ["CUDA_VISIBLE_DEVICES"] = "" import paddle import paddle.fluid as fluid @@ -49,7 +48,7 @@ def parse_args(): parser.add_argument( '--num_passes', type=int, - default=10, + default=1, help="The number of passes to train (default: 10)") parser.add_argument( '--model_output_dir', @@ -126,14 +125,35 @@ def parse_args(): return parser.parse_args() -def train_loop(args, train_program, reader, py_reader, loss, trainer_id): - train_reader = paddle.batch( - paddle.reader.shuffle( - reader.train((args.with_hs or (not args.with_nce))), - buf_size=args.batch_size * 100), - batch_size=args.batch_size) +def convert_python_to_tensor(batch_size, sample_reader): + def __reader__(): + result = [[], [], [], []] + for sample in sample_reader(): + for i, fea in enumerate(sample): + result[i].append(fea) + + if len(result[0]) == batch_size: + tensor_result = [] + for tensor in result: + t = fluid.Tensor() + dat = np.array(tensor, dtype='int64') + if len(dat.shape) > 2: + dat = dat.reshape((dat.shape[0], dat.shape[2])) + elif len(dat.shape) == 1: + dat = dat.reshape((-1, 1)) + t.set(dat, fluid.CPUPlace()) + + tensor_result.append(t) + yield tensor_result + result = [[], [], [], []] - py_reader.decorate_paddle_reader(train_reader) + return __reader__ + + +def train_loop(args, train_program, reader, py_reader, loss, trainer_id): + py_reader.decorate_tensor_provider(convert_python_to_tensor( + args.batch_size, reader.train((args.with_hs or (not args.with_nce))))) + # py_reader.decorate_paddle_reader(train_reader) place = fluid.CPUPlace() @@ -144,6 +164,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): print("CPU_NUM:" + str(os.getenv("CPU_NUM"))) exec_strategy.num_threads = int(os.getenv("CPU_NUM")) + exec_strategy.use_experimental_executor = True build_strategy = fluid.BuildStrategy() if int(os.getenv("CPU_NUM")) > 1: @@ -156,43 +177,31 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): build_strategy=build_strategy, exec_strategy=exec_strategy) - profile_state = "CPU" - profiler_step = 0 - profiler_step_start = 20 - profiler_step_end = 30 - for pass_id in range(args.num_passes): - epoch_start = time.time() py_reader.start() + time.sleep(10) # wait reading data. + epoch_start = time.time() batch_id = 0 start = time.clock() - + try: while True: - - if profiler_step == profiler_step_start: - fluid.profiler.start_profiler(profile_state) - loss_val = train_exe.run(fetch_list=[loss.name]) loss_val = np.mean(loss_val) - if profiler_step == profiler_step_end: - fluid.profiler.stop_profiler('total', 'trainer_profile.log') - profiler_step += 1 - else: - profiler_step += 1 - if batch_id % 50 == 0: logger.info( "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}". - format(pass_id, batch_id, - loss_val.mean() / args.batch_size, - py_reader.queue.size())) + format(pass_id, batch_id, + loss_val, + py_reader.queue.size())) + if batch_id == 1000: + exit(0) if args.with_speed: - if batch_id % 1000 == 0 and batch_id != 0: + if batch_id % 100 == 0 and batch_id != 0: elapsed = (time.clock() - start) start = time.clock() - samples = 1001 * args.batch_size * int( + samples = 101 * args.batch_size * int( os.getenv("CPU_NUM")) logger.info("Time used: {}, Samples/Sec: {}".format( elapsed, samples / elapsed)) @@ -229,11 +238,12 @@ def GetFileList(data_path): def train(args): - + print("I am ehre") if not os.path.isdir(args.model_output_dir): os.mkdir(args.model_output_dir) - filelist = GetFileList(args.train_data_path) + filelist = GetFileList(args.train_data_path)[:1] + print(filelist) word2vec_reader = None if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1": word2vec_reader = reader.Word2VecReader( @@ -329,7 +339,7 @@ def env_declar(): print("%30s %s \n" % (key, os.environ[key])) if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[ - "PADDLE_IS_LOCAL"] == "0": + "PADDLE_IS_LOCAL"] == "0": os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"] os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"] os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"] -- GitLab