From f67756ead8c11c9b62ace05708efb9c502eb7fe2 Mon Sep 17 00:00:00 2001
From: JiabinYang
Date: Fri, 16 Nov 2018 11:57:24 +0000
Subject: [PATCH] test=develop merge nce part with hsigmoid part

---
 fluid/PaddleRec/word2vec/network_conf.py | 78 ++++++++++++++++----
 fluid/PaddleRec/word2vec/preprocess.py   | 94 ++++++++++++++++++++++++
 fluid/PaddleRec/word2vec/reader.py       | 51 ++++++++++++-
 fluid/PaddleRec/word2vec/train.py        | 32 +++++++-
 4 files changed, 232 insertions(+), 23 deletions(-)

diff --git a/fluid/PaddleRec/word2vec/network_conf.py b/fluid/PaddleRec/word2vec/network_conf.py
index 373681fe..2d74207c 100644
--- a/fluid/PaddleRec/word2vec/network_conf.py
+++ b/fluid/PaddleRec/word2vec/network_conf.py
@@ -22,8 +22,15 @@ import numpy as np
 import paddle.fluid as fluid
 
 
-def skip_gram_word2vec(dict_size, word_frequencys, embedding_size):
-    def nce_layer(input, label, embedding_size, num_total_classes, num_neg_samples, sampler, custom_dist, sample_weight):
+
+def skip_gram_word2vec(dict_size,
+                       word_frequencys,
+                       embedding_size,
+                       max_code_length=None,
+                       with_hsigmoid=False,
+                       with_nce=True):
+    def nce_layer(input, label, embedding_size, num_total_classes,
+                  num_neg_samples, sampler, custom_dist, sample_weight):
         # convert word_frequencys to tensor
         nid_freq_arr = np.array(word_frequencys).astype('float32')
         nid_freq_var = fluid.layers.assign(input=nid_freq_arr)
@@ -31,33 +38,72 @@ def skip_gram_word2vec(dict_size, word_frequencys, embedding_size):
         w_param_name = "nce_w"
         b_param_name = "nce_b"
         w_param = fluid.default_main_program().global_block().create_parameter(
-            shape=[num_total_classes, embedding_size], dtype='float32', name=w_param_name)
+            shape=[num_total_classes, embedding_size],
+            dtype='float32',
+            name=w_param_name)
         b_param = fluid.default_main_program().global_block().create_parameter(
             shape=[num_total_classes, 1], dtype='float32', name=b_param_name)
 
-        cost = fluid.layers.nce(
-            input=input,
-            label=label,
-            num_total_classes=num_total_classes,
-            sampler=sampler,
-            custom_dist=nid_freq_var,
-            sample_weight = sample_weight,
-            param_attr=fluid.ParamAttr(name=w_param_name),
-            bias_attr=fluid.ParamAttr(name=b_param_name),
-            num_neg_samples=num_neg_samples)
+        cost = fluid.layers.nce(input=input,
+                                label=label,
+                                num_total_classes=num_total_classes,
+                                sampler=sampler,
+                                custom_dist=nid_freq_var,
+                                sample_weight=sample_weight,
+                                param_attr=fluid.ParamAttr(name=w_param_name),
+                                bias_attr=fluid.ParamAttr(name=b_param_name),
+                                num_neg_samples=num_neg_samples)
+
+        return cost
+
+    def hsigmoid_layer(input, label, non_leaf_num, max_code_length, data_list):
+        hs_cost = None
+        ptable = None
+        pcode = None
+        if max_code_length != None:
+            ptable = fluid.layers.data(
+                name='ptable', shape=[max_code_length], dtype='int64')
+            pcode = fluid.layers.data(
+                name='pcode', shape=[max_code_length], dtype='int64')
+            data_list.append(pcode)
+            data_list.append(ptable)
+        else:
+            ptable = fluid.layers.data(name='ptable', shape=[40], dtype='int64')
+            pcode = fluid.layers.data(name='pcode', shape=[40], dtype='int64')
+            data_list.append(pcode)
+            data_list.append(ptable)
+        if non_leaf_num == None:
+            non_leaf_num = dict_size
+
+        cost = fluid.layers.hsigmoid(
+            input=emb,
+            label=predict_word,
+            non_leaf_num=non_leaf_num,
+            ptable=ptable,
+            pcode=pcode,
+            is_costum=True)
 
         return cost
 
     input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
-    predict_word = fluid.layers.data(name='predict_word', shape=[1], dtype='int64')
+    predict_word = fluid.layers.data(
+        name='predict_word', shape=[1], dtype='int64')
+    cost = None
     data_list = [input_word, predict_word]
     emb = fluid.layers.embedding(
         input=input_word,
         size=[dict_size, embedding_size],
-        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(scale=1 / math.sqrt(dict_size))))
+        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+            scale=1 / math.sqrt(dict_size))))
+
+    if with_nce:
+        cost = nce_layer(emb, predict_word, embedding_size, dict_size, 5,
+                         "uniform", word_frequencys, None)
+    if with_hsigmoid:
+        cost = hsigmoid_layer(emb, predict_word, dict_size, max_code_length,
+                              data_list)
 
-    cost = nce_layer(emb, predict_word, embedding_size, dict_size, 5, "uniform", word_frequencys, None)
     avg_cost = fluid.layers.reduce_mean(cost)
 
     return avg_cost, data_list
diff --git a/fluid/PaddleRec/word2vec/preprocess.py b/fluid/PaddleRec/word2vec/preprocess.py
index f19ac768..89645c2f 100644
--- a/fluid/PaddleRec/word2vec/preprocess.py
+++ b/fluid/PaddleRec/word2vec/preprocess.py
@@ -30,6 +30,90 @@ def text_strip(text):
     return re.sub("[^a-z ]", "", text)
 
 
+def build_Huffman(word_count, max_code_length):
+
+    MAX_CODE_LENGTH = max_code_length
+    sorted_by_freq = sorted(word_count.items(), key=lambda x: x[1])
+    count = list()
+    vocab_size = len(word_count)
+    parent = [-1] * 2 * vocab_size
+    code = [-1] * MAX_CODE_LENGTH
+    point = [-1] * MAX_CODE_LENGTH
+    binary = [-1] * 2 * vocab_size
+    word_code_len = dict()
+    word_code = dict()
+    word_point = dict()
+    i = 0
+    for a in range(vocab_size):
+        count.append(word_count[sorted_by_freq[a][0]])
+
+    for a in range(vocab_size):
+        word_point[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH
+        word_code[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH
+
+    for k in range(vocab_size):
+        count.append(1e15)
+
+    pos1 = vocab_size - 1
+    pos2 = vocab_size
+    min1i = 0
+    min2i = 0
+    b = 0
+
+    for r in range(vocab_size):
+        if pos1 >= 0:
+            if count[pos1] < count[pos2]:
+                min1i = pos1
+                pos1 = pos1 - 1
+            else:
+                min1i = pos2
+                pos2 = pos2 + 1
+        else:
+            min1i = pos2
+            pos2 = pos2 + 1
+        if pos1 >= 0:
+            if count[pos1] < count[pos2]:
+                min2i = pos1
+                pos1 = pos1 - 1
+            else:
+                min2i = pos2
+                pos2 = pos2 + 1
+        else:
+            min2i = pos2
+            pos2 = pos2 + 1
+
+        count[vocab_size + r] = count[min1i] + count[min2i]
+
+        #record the parent of left and right child
+        parent[min1i] = vocab_size + r
+        parent[min2i] = vocab_size + r
+        binary[min1i] = 0  #left branch has code 0
+        binary[min2i] = 1  #right branch has code 1
+
+    for a in range(vocab_size):
+        b = a
+        i = 0
+        while True:
+            code[i] = binary[b]
+            point[i] = b
+            i = i + 1
+            b = parent[b]
+            if b == vocab_size * 2 - 2:
+                break
+
+        word_code_len[sorted_by_freq[a][0]] = i
+        word_point[sorted_by_freq[a][0]][0] = vocab_size - 2
+
+        for k in range(i):
+            word_code[sorted_by_freq[a][0]][i - k - 1] = code[k]
+
+            # only non-leaf nodes will be count in
+            if point[k] - vocab_size >= 0:
+                word_point[sorted_by_freq[a][0]][i - k] = point[k] - vocab_size
+
+    return word_point, word_code, word_code_len
+
+
 def preprocess(data_path, dict_path, freq):
     """
     proprocess the data, generate dictionary and save into dict_path.
@@ -58,10 +142,20 @@ def preprocess(data_path, dict_path, freq):
     for item in item_to_remove:
         del word_count[item]
 
+    path_table, path_code, word_code_len = build_Huffman(word_count, 40)
+
     with open(dict_path, 'w+') as f:
         for k, v in word_count.items():
             f.write(str(k) + " " + str(v) + '\n')
 
+    with open(dict_path + "_ptable", 'w+') as f2:
+        for pk, pv in path_table.items():
+            f2.write(str(pk) + ":" + ' '.join((str(x) for x in pv)) + '\n')
+
+    with open(dict_path + "_pcode", 'w+') as f3:
+        for pck, pcv in path_table.items():
+            f3.write(str(pck) + ":" + ' '.join((str(x) for x in pcv)) + '\n')
+
 
 if __name__ == "__main__":
     args = parse_args()
diff --git a/fluid/PaddleRec/word2vec/reader.py b/fluid/PaddleRec/word2vec/reader.py
index 6f411283..04600044 100644
--- a/fluid/PaddleRec/word2vec/reader.py
+++ b/fluid/PaddleRec/word2vec/reader.py
@@ -8,7 +8,11 @@ class Word2VecReader(object):
     def __init__(self, dict_path, data_path, window_size=5):
         self.window_size_ = window_size
         self.data_path_ = data_path
+        self.num_non_leaf = 0
         self.word_to_id_ = dict()
+        self.id_to_word = dict()
+        self.word_to_path = dict()
+        self.word_to_code = dict()
 
         word_all_count = 0
         word_counts = []
@@ -18,13 +22,31 @@ class Word2VecReader(object):
             for line in f:
                 word, count = line.split()[0], int(line.split()[1])
                 self.word_to_id_[word] = word_id
+                self.id_to_word[word_id] = word  #build id to word dict
                 word_id += 1
                 word_counts.append(count)
                 word_all_count += count
 
         self.dict_size = len(self.word_to_id_)
-        self.word_frequencys = [ float(count)/word_all_count for count in word_counts]
-        print("dict_size = " + str(self.dict_size)) + " word_all_count = " + str(word_all_count)
+        self.word_frequencys = [
+            float(count) / word_all_count for count in word_counts
+        ]
+        print("dict_size = " + str(
+            self.dict_size)) + " word_all_count = " + str(word_all_count)
+
+        with open(dict_path + "_ptable", 'r') as f2:
+            for line in f2:
+                self.word_to_path[line.split(":")[0]] = np.fromstring(
+                    line.split(':')[1], dtype=int, sep=' ')
+                self.num_non_leaf = np.fromstring(
+                    line.split(':')[1], dtype=int, sep=' ')[0]
+        print("word_ptable dict_size = " + str(len(self.word_to_path)))
+
+        with open(dict_path + "_pcode", 'r') as f3:
+            for line in f3:
+                self.word_to_code[line.split(":")[0]] = np.fromstring(
+                    line.split(':')[1], dtype=int, sep=' ')
+        print("word_pcode dict_size = " + str(len(self.word_to_code)))
 
     def get_context_words(self, words, idx, window_size):
         """
@@ -42,7 +64,7 @@ class Word2VecReader(object):
         targets = set(words[start_point:idx] + words[idx + 1:end_point + 1])
         return list(targets)
 
-    def train(self):
+    def train(self, with_hs):
         def _reader():
             with open(self.data_path_, 'r') as f:
                 for line in f:
@@ -57,7 +79,28 @@ class Word2VecReader(object):
                         for context_id in context_word_ids:
                             yield [target_id], [context_id]
 
-        return _reader
+        def _reader_hs():
+            with open(self.data_path_, 'r') as f:
+                for line in f:
+                    line = preprocess.text_strip(line)
+                    word_ids = [
+                        self.word_to_id_[word] for word in line.split()
+                        if word in self.word_to_id_
+                    ]
+                    for idx, target_id in enumerate(word_ids):
+                        context_word_ids = self.get_context_words(
+                            word_ids, idx, self.window_size_)
+                        for context_id in context_word_ids:
+                            yield [target_id], [context_id], [
+                                self.word_to_code[self.id_to_word[context_id]]
+                            ], [
+                                self.word_to_path[self.id_to_word[context_id]]
+                            ]
+
+        if not with_hs:
+            return _reader
+        else:
+            return _reader_hs
 
 
 if __name__ == "__main__":
diff --git a/fluid/PaddleRec/word2vec/train.py b/fluid/PaddleRec/word2vec/train.py
index 8581a9c3..9d895561 100644
--- a/fluid/PaddleRec/word2vec/train.py
+++ b/fluid/PaddleRec/word2vec/train.py
@@ -3,6 +3,7 @@ from __future__ import print_function
 import argparse
 import logging
 import os
+import time
 
 # disable gpu training for this example
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -87,6 +88,21 @@ def parse_args():
         type=int,
         default=1,
         help='The num of trianers, (default: 1)')
+    parser.add_argument(
+        '--with_hs',
+        type=int,
+        default=0,
+        help='using hierarchical sigmoid, (default: 0)')
+    parser.add_argument(
+        '--with_nce',
+        type=int,
+        default=1,
+        help='using negtive sampling, (default: 1)')
+    parser.add_argument(
+        '--max_code_length',
+        type=int,
+        default=40,
+        help='max code length used by hierarchical sigmoid, (default: 40)')
 
     return parser.parse_args()
 
@@ -95,15 +111,18 @@ def train_loop(args, train_program, reader, data_list, loss, trainer_num,
                trainer_id):
     train_reader = paddle.batch(
         paddle.reader.shuffle(
-            reader.train(), buf_size=args.batch_size * 100),
+            reader.train((args.with_hs or (not args.with_nce))),
+            buf_size=args.batch_size * 100),
         batch_size=args.batch_size)
 
     place = fluid.CPUPlace()
 
     feeder = fluid.DataFeeder(feed_list=data_list, place=place)
+    data_name_list = [var.name for var in data_list]
 
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
+    start = time.clock()
     for pass_id in range(args.num_passes):
         for batch_id, data in enumerate(train_reader()):
             loss_val = exe.run(train_program,
@@ -112,6 +131,10 @@ def train_loop(args, train_program, reader, data_list, loss, trainer_num,
             if batch_id % 10 == 0:
                 logger.info("TRAIN --> pass: {} batch: {} loss: {}".format(
                     pass_id, batch_id, loss_val[0] / args.batch_size))
+            if batch_id % 1000 == 0 and batch_id != 0:
+                elapsed = (time.clock() - start)
+                logger.info("Time used: {}".format(elapsed))
+
             if batch_id % 1000 == 0 and batch_id != 0:
                 model_dir = args.model_output_dir + '/batch-' + str(batch_id)
                 if args.trainer_id == 0:
@@ -133,9 +156,12 @@ def train():
                                    args.train_data_path)
 
     logger.info("dict_size: {}".format(word2vec_reader.dict_size))
-    logger.info("word_frequencys length: {}".format(len(word2vec_reader.word_frequencys)))
+    logger.info("word_frequencys length: {}".format(
+        len(word2vec_reader.word_frequencys)))
 
-    loss, data_list = skip_gram_word2vec(word2vec_reader.dict_size, word2vec_reader.word_frequencys, args.embedding_size)
+    loss, data_list = skip_gram_word2vec(
+        word2vec_reader.dict_size, word2vec_reader.word_frequencys,
+        args.embedding_size, args.max_code_length, args.with_hs, args.with_nce)
 
     optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
     optimizer.minimize(loss)
-- 
GitLab
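
A minimal standalone sketch of how the build_Huffman helper added to preprocess.py can be exercised on its own; it assumes the snippet is run from fluid/PaddleRec/word2vec/ so the module is importable, and the toy word_count dict below is purely illustrative:

    # Hypothetical usage example; the vocabulary and counts are made up.
    from preprocess import build_Huffman

    # Toy word -> frequency map; real counts come from preprocess().
    word_count = {"the": 120, "of": 80, "cat": 15, "dog": 10, "zebra": 2}

    # 40 matches the default --max_code_length added to train.py.
    word_point, word_code, word_code_len = build_Huffman(word_count, 40)

    for w in word_count:
        n = word_code_len[w]
        print(w, "code:", word_code[w][:n], "path:", word_point[w][:n])

preprocess() writes these paths and codes to dict_path + "_ptable" and dict_path + "_pcode", which Word2VecReader loads and feeds as the extra pcode/ptable inputs when train.py is launched with --with_hs 1 --with_nce 0.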