From e99b369d8f4fe966b59cb634e18677b46810e937 Mon Sep 17 00:00:00 2001 From: zhangwenhui03 Date: Thu, 21 Mar 2019 19:35:50 +0800 Subject: [PATCH] word2vec new version --- fluid/PaddleRec/word2vec/README.cn.md | 99 ----- fluid/PaddleRec/word2vec/README.md | 110 ++--- fluid/PaddleRec/word2vec/__init__.py | 0 fluid/PaddleRec/word2vec/cluster_train.sh | 41 -- fluid/PaddleRec/word2vec/data/download.sh | 4 - fluid/PaddleRec/word2vec/infer.py | 467 +++++++++------------- fluid/PaddleRec/word2vec/net.py | 144 +++++++ fluid/PaddleRec/word2vec/network_conf.py | 129 ------ fluid/PaddleRec/word2vec/preprocess.py | 273 +++++-------- fluid/PaddleRec/word2vec/reader.py | 120 +----- fluid/PaddleRec/word2vec/train.py | 291 ++++---------- fluid/PaddleRec/word2vec/utils.py | 96 +++++ 12 files changed, 673 insertions(+), 1101 deletions(-) delete mode 100644 fluid/PaddleRec/word2vec/README.cn.md delete mode 100644 fluid/PaddleRec/word2vec/__init__.py delete mode 100644 fluid/PaddleRec/word2vec/cluster_train.sh delete mode 100644 fluid/PaddleRec/word2vec/data/download.sh create mode 100644 fluid/PaddleRec/word2vec/net.py delete mode 100644 fluid/PaddleRec/word2vec/network_conf.py create mode 100644 fluid/PaddleRec/word2vec/utils.py diff --git a/fluid/PaddleRec/word2vec/README.cn.md b/fluid/PaddleRec/word2vec/README.cn.md deleted file mode 100644 index 13e79c41..00000000 --- a/fluid/PaddleRec/word2vec/README.cn.md +++ /dev/null @@ -1,99 +0,0 @@ - -# 基于skip-gram的word2vector模型 - -## 介绍 - - -## 运行环境 -需要先安装PaddlePaddle Fluid - -## 数据集 -数据集使用的是来自1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark)的数据集. - -下载数据集: -```bash -cd data && ./download.sh && cd .. -``` - -## 模型 -本例子实现了一个skip-gram模式的word2vector模型。 - - -## 数据准备 -对数据进行预处理以生成一个词典。 - -```bash -python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict -``` -如果您想使用自定义的词典形如: -```bash - -a -b -c -``` -请将--other_dict_path设置为您存放将使用的词典的目录,并设置--with_other_dict使用它 - -## 训练 -训练的命令行选项可以通过`python train.py -h`列出。 - -### 单机训练: - -```bash -export CPU_NUM=1 -python train.py \ - --train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \ - --dict_path data/1-billion_dict \ - --with_hs --with_nce --is_local \ - 2>&1 | tee train.log -``` -如果您想使用自定义的词典形如: -```bash - -a -b -c -``` -请将--other_dict_path设置为您存放将使用的词典的目录,并设置--with_other_dict使用它 - -### 分布式训练 - -本地启动一个2 trainer 2 pserver的分布式训练任务,分布式场景下训练数据会按照trainer的id进行切分,保证trainer之间的训练数据不会重叠,提高训练效率 - -```bash -sh cluster_train.sh -``` - -## 预测 -在infer.py中我们在`build_test_case`方法中构造了一些test case来评估word embeding的效果: -我们输入test case( 我们目前采用的是analogical-reasoning的任务:找到A - B = C - D的结构,为此我们计算A - B + D,通过cosine距离找最近的C,计算准确率要去除候选中出现A、B、D的候选 )然后计算候选和整个embeding中所有词的余弦相似度,并且取topK(K由参数 --rank_num确定,默认为4)打印出来。 - -如: -对于:boy - girl + aunt = uncle -0 nearest aunt:0.89 -1 nearest uncle:0.70 -2 nearest grandmother:0.67 -3 nearest father:0.64 - -您也可以在`build_test_case`方法中模仿给出的例子增加自己的测试 - -要从测试文件运行测试用例,请将测试文件下载到“test”目录中 -我们为每个案例提供以下结构的测试: - `word1 word2 word3 word4` -所以我们可以将它构建成`word1 - word2 + word3 = word4` - -训练中预测: - -```bash -python infer.py --infer_during_train 2>&1 | tee infer.log -``` -使用某个model进行离线预测: - -```bash -python infer.py --infer_once --model_output_dir ./models/[具体的models文件目录] 2>&1 | tee infer.log -``` -## 在百度云上运行集群训练 -1. 
参考文档 [在百度云上启动Fluid分布式训练](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst) 在百度云上部署一个CPU集群。 -1. 用preprocess.py处理训练数据生成train.txt。 -1. 将train.txt切分成集群机器份,放到每台机器上。 -1. 用上面的 `分布式训练` 中的命令行启动分布式训练任务. diff --git a/fluid/PaddleRec/word2vec/README.md b/fluid/PaddleRec/word2vec/README.md index 17a61a42..812fe1e3 100644 --- a/fluid/PaddleRec/word2vec/README.md +++ b/fluid/PaddleRec/word2vec/README.md @@ -1,103 +1,73 @@ +# 基于skip-gram的word2vector模型 -# Skip-Gram Word2Vec Model +以下是本例的简要目录结构及说明: -## Introduction +```text +. +├── infer.py # 预测脚本 +├── net.py # 网络结构 +├── preprocess.py # 预处理脚本,包括构建词典和预处理文本 +├── reader.py # 训练阶段的文本读写 +├── README.md # 使用说明 +├── train.py # 训练函数 +└── utils.py # 通用函数 +``` + +## 介绍 +本例实现了skip-gram模式的word2vector模型。 -## Environment -You should install PaddlePaddle Fluid first. -## Dataset -The training data for the 1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark). +## 数据集 +大数据集使用的是来自1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark)的数据集.下载命令如下 -Download dataset: ```bash -cd data && ./download.sh && cd .. +wget https://paddle-zwh.bj.bcebos.com/1-billion-word-language-modeling-benchmark-r13output.tar ``` -if you would like to use our supported third party vocab, please run: +小数据集使用1700w个词的text8数据集,下载命令如下 + +下载数据集: ```bash -wget http://download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt +wget https://paddle-zwh.bj.bcebos.com/text.tar ``` -## Model -This model implement a skip-gram model of word2vector. +## 数据预处理 +以下以小数据为例进行预处理。 -## Data Preprocessing method +大数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和小数据集的text目录并列。 -Preprocess the training data to generate a word dict. +根据英文语料生成词典, 中文语料可以通过修改text_strip ```bash -python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict +python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict ``` -if you would like to use your own vocab follow the format below: -```bash - -a -b -c -``` -Then, please set --other_dict_path as the directory of where you -save the vocab you will use and set --with_other_dict flag on to using it. -## Train -The command line options for training can be listed by `python train.py -h`. +根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词。 -### Local Train: -we set CPU_NUM=1 as default CPU_NUM to execute ```bash -export CPU_NUM=1 && \ -python train.py \ - --train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \ - --dict_path data/1-billion_dict \ - --with_hs --with_nce --is_local \ - 2>&1 | tee train.log +python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text/ --output_corpus_dir data/convert_text8 --min_count 5 --downsample 0.001 ``` -if you would like to use our supported third party vocab, please set --other_dict_path as the directory of where you -save the vocab you will use and set --with_other_dict flag on to using it. -### Distributed Train -Run a 2 pserver 2 trainer distribute training on a single machine. 
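The `--downsample` option used in the preprocessing command above keeps each occurrence of a word with a probability that shrinks as the word gets more frequent; the rule is the one implemented in `filter_corpus` in `preprocess.py` further down in this patch. A small standalone sketch of that rule, with made-up counts purely for illustration:

```python
import math
import random

def keep_probability(count_w, corpus_size, downsample=0.001):
    # Probability of keeping one occurrence of a word that appears count_w
    # times in a corpus of corpus_size tokens (word2vec-style subsampling).
    t = downsample * corpus_size
    return (math.sqrt(count_w / t) + 1.0) * t / count_w

# Illustrative numbers only: a word that makes up 10% of a 17M-token corpus
# is kept with probability ~0.11, i.e. roughly 89% of its occurrences are dropped.
p = keep_probability(count_w=1700000, corpus_size=17000000, downsample=0.001)
print(round(p, 3))          # 0.11
print(random.random() < p)  # True -> keep this occurrence, False -> skip it
```

Rare words get a keep probability above 1.0 and are therefore always kept, which is why only very frequent words are thinned out.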
-In distributed training setting, training data is splited by trainer_id, so that training data - do not overlap among trainers +## 训练 +cpu 单机多线程训练 ```bash -sh cluster_train.sh +OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes 10 --batch_size 100 --model_output_dir v1_cpu5_b100_lr1dir --base_lr 1.0 --print_batch 1000 --with_speed --is_sparse ``` -## Infer - -In infer.py we construct some test cases in the `build_test_case` method to evaluate the effect of word embeding: -We enter the test case (we are currently using the analogical-reasoning task: find the structure of A - B = C - D, for which we calculate A - B + D, find the nearest C by cosine distance, the calculation accuracy is removed Candidates for A, B, and D appear in the candidate) Then calculate the cosine similarity of the candidate and all words in the entire embeding, and print out the topK (K is determined by the parameter --rank_num, the default is 4). - -Such as: -For: boy - girl + aunt = uncle -0 nearest aunt: 0.89 -1 nearest uncle: 0.70 -2 nearest grandmother: 0.67 -3 nearest father:0.64 - -You can also add your own tests by mimicking the examples given in the `build_test_case` method. - -To running test case from test files, please download the test files into 'test' directory -we provide test for each case with the following structure: - `word1 word2 word3 word4` -so we can build it into `word1 - word2 + word3 = word4` - -Forecast in training: +## 预测 +测试集下载命令如下 ```bash -Python infer.py --infer_during_train 2>&1 | tee infer.log +#大数据集测试集 +wget https://paddle-zwh.bj.bcebos.com/test_dir.tar +#小数据集测试集 +wget https://paddle-zwh.bj.bcebos.com/test_mid_dir.tar ``` -Use a model for offline prediction: +预测命令 ```bash -Python infer.py --infer_once --model_output_dir ./models/[specific models file directory] 2>&1 | tee infer.log +python infer.py --infer_epoch --test_dir data/test_mid_dir/ --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0 ``` - -## Train on Baidu Cloud -1. Please prepare some CPU machines on Baidu Cloud following the steps in [train_on_baidu_cloud](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst) -1. Prepare dataset using preprocess.py. -1. Split the train.txt to trainer_num parts and put them on the machines. -1. Run training with the cluster train using the command in `Distributed Train` above. diff --git a/fluid/PaddleRec/word2vec/__init__.py b/fluid/PaddleRec/word2vec/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fluid/PaddleRec/word2vec/cluster_train.sh b/fluid/PaddleRec/word2vec/cluster_train.sh deleted file mode 100644 index 87537efb..00000000 --- a/fluid/PaddleRec/word2vec/cluster_train.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -echo "WARNING: This script only for run PaddlePaddle Fluid on one node..." -echo "" - -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib -export PADDLE_PSERVER_PORTS=36001,36002 -export PADDLE_PSERVER_PORT_ARRAY=(36001 36002) -export PADDLE_PSERVERS=2 - -export PADDLE_IP=127.0.0.1 -export PADDLE_TRAINERS=2 - -export CPU_NUM=2 -export NUM_THREADS=2 -export PADDLE_SYNC_MODE=TRUE -export PADDLE_IS_LOCAL=0 - -export FLAGS_rpc_deadline=3000000 -export GLOG_logtostderr=1 - - -export TRAIN_DATA=data/enwik8 -export DICT_PATH=data/enwik8_dict -export IS_SPARSE="--is_sparse" - - -echo "Start PSERVER ..." 
-for((i=0;i<$PADDLE_PSERVERS;i++)) -do - cur_port=${PADDLE_PSERVER_PORT_ARRAY[$i]} - echo "PADDLE WILL START PSERVER "$cur_port - GLOG_v=0 PADDLE_TRAINING_ROLE=PSERVER CUR_PORT=$cur_port PADDLE_TRAINER_ID=$i python -u train.py $IS_SPARSE &> pserver.$i.log & -done - -echo "Start TRAINER ..." -for((i=0;i<$PADDLE_TRAINERS;i++)) -do - echo "PADDLE WILL START Trainer "$i - GLOG_v=0 PADDLE_TRAINER_ID=$i PADDLE_TRAINING_ROLE=TRAINER python -u train.py $IS_SPARSE --train_data_path $TRAIN_DATA --dict_path $DICT_PATH &> trainer.$i.log & -done \ No newline at end of file diff --git a/fluid/PaddleRec/word2vec/data/download.sh b/fluid/PaddleRec/word2vec/data/download.sh deleted file mode 100644 index 22cde6d9..00000000 --- a/fluid/PaddleRec/word2vec/data/download.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz -tar -zxvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz diff --git a/fluid/PaddleRec/word2vec/infer.py b/fluid/PaddleRec/word2vec/infer.py index c0dd82ef..9a364950 100644 --- a/fluid/PaddleRec/word2vec/infer.py +++ b/fluid/PaddleRec/word2vec/infer.py @@ -1,294 +1,213 @@ +import argparse +import sys import time -import os -import paddle.fluid as fluid +import math +import unittest +import contextlib import numpy as np -import logging -import argparse -import preprocess - -word_to_id = dict() -id_to_word = dict() - -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger("fluid") -logger.setLevel(logging.INFO) +import six +import paddle.fluid as fluid +import paddle +import net +import utils def parse_args(): - parser = argparse.ArgumentParser( - description="PaddlePaddle Word2vec infer example") + parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example") parser.add_argument( '--dict_path', type=str, - default='./data/1-billion_dict', - help="The path of training dataset") + default='./data/data_c/1-billion_dict_word_to_id_', + help="The path of dic") parser.add_argument( - '--model_output_dir', - type=str, - default='models', - help="The path for model to store (with infer_once please set specify dir to models) (default: models)" - ) - parser.add_argument( - '--rank_num', - type=int, - default=4, - help="find rank_num-nearest result for test (default: 4)") - parser.add_argument( - '--infer_once', + '--infer_epoch', action='store_true', required=False, default=False, - help='if using infer_once, (default: False)') - parser.add_argument( - '--infer_during_train', - action='store_true', - required=False, - default=True, - help='if using infer_during_train, (default: True)') + help='infer by epoch') parser.add_argument( - '--test_acc', + '--infer_step', action='store_true', required=False, default=False, - help='if using test_files , (default: False)') + help='infer by step') parser.add_argument( - '--test_files_dir', - type=str, - default='test', - help="The path for test_files) (default: test)") + '--test_dir', type=str, default='test_data', help='test file address') parser.add_argument( - '--test_batch_size', - type=int, - default=1000, - help="test used batch size (default: 1000)") - - return parser.parse_args() - - -def BuildWord_IdMap(dict_path): - with open(dict_path + "_word_to_id_", 'r') as f: - for line in f: - word_to_id[line.split(' ')[0]] = int(line.split(' ')[1]) - id_to_word[int(line.split(' ')[1])] = line.split(' ')[0] - - -def inference_prog(): # just to create program for test - fluid.layers.create_parameter( - 
shape=[1, 1], dtype='float32', name="embeding") - - -def build_test_case_from_file(args, emb): - logger.info("test files dir: {}".format(args.test_files_dir)) - current_list = os.listdir(args.test_files_dir) - logger.info("test files list: {}".format(current_list)) - test_cases = list() - test_labels = list() - test_case_descs = list() - exclude_lists = list() - for file_dir in current_list: - with open(args.test_files_dir + "/" + file_dir, 'r') as f: - for line in f: - if ':' in line: - logger.info("{}".format(line)) - pass - else: - line = preprocess.strip_lines(line, word_to_id) - test_case = emb[word_to_id[line.split()[0]]] - emb[ - word_to_id[line.split()[1]]] + emb[word_to_id[ - line.split()[2]]] - test_case_desc = line.split()[0] + " - " + line.split()[ - 1] + " + " + line.split()[2] + " = " + line.split()[3] - test_cases.append(test_case) - test_case_descs.append(test_case_desc) - test_labels.append(word_to_id[line.split()[3]]) - exclude_lists.append([ - word_to_id[line.split()[0]], - word_to_id[line.split()[1]], word_to_id[line.split()[2]] - ]) - test_cases = norm(np.array(test_cases)) - return test_cases, test_case_descs, test_labels, exclude_lists - - -def build_small_test_case(emb): - emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[ - 'aunt']] - desc1 = "boy - girl + aunt = uncle" - label1 = word_to_id["uncle"] - emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[ - word_to_id['sisters']] - desc2 = "brother - sister + sisters = brothers" - label2 = word_to_id["brothers"] - emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[ - 'woman']] - desc3 = "king - queen + woman = man" - label3 = word_to_id["man"] - emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[ - word_to_id['slowly']] - desc4 = "reluctant - reluctantly + slowly = slow" - label4 = word_to_id["slow"] - emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[ - 'deeper']] - desc5 = "old - older + deeper = deep" - label5 = word_to_id["deep"] - - emb6 = emb[word_to_id['boy']] - desc6 = "boy" - label6 = word_to_id["boy"] - emb7 = emb[word_to_id['king']] - desc7 = "king" - label7 = word_to_id["king"] - emb8 = emb[word_to_id['sun']] - desc8 = "sun" - label8 = word_to_id["sun"] - emb9 = emb[word_to_id['key']] - desc9 = "key" - label9 = word_to_id["key"] - test_cases = [emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8, emb9] - test_case_desc = [ - desc1, desc2, desc3, desc4, desc5, desc6, desc7, desc8, desc9 - ] - test_labels = [ - label1, label2, label3, label4, label5, label6, label7, label8, label9 - ] - return norm(np.array(test_cases)), test_case_desc, test_labels - - -def build_test_case(args, emb): - if args.test_acc: - return build_test_case_from_file(args, emb) - else: - return build_small_test_case(emb) - - -def norm(x): - y = np.linalg.norm(x, axis=1, keepdims=True) - return x / y - - -def inference_test(scope, model_dir, args): - BuildWord_IdMap(args.dict_path) - logger.info("model_dir is: {}".format(model_dir + "/")) - emb = np.array(scope.find_var("embeding").get_tensor()) - x = norm(emb) - logger.info("inference result: ====================") - test_cases = None - test_case_desc = list() - test_labels = list() - exclude_lists = list() - if args.test_acc: - test_cases, test_case_desc, test_labels, exclude_lists = build_test_case( - args, emb) - else: - test_cases, test_case_desc, test_labels = build_test_case(args, emb) - exclude_lists = [[-1]] - accual_rank = 1 if args.test_acc else args.rank_num - 
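The analogy check built above comes down to three numpy operations: L2-normalize the embedding table, form the query vector `emb[word1] - emb[word2] + emb[word3]`, and rank every word by cosine similarity to the query while skipping the three query words. A compact sketch of that idea with a toy vocabulary and random placeholder vectors (a real run reads the trained `embeding` parameter and `word_to_id` instead):

```python
import numpy as np

rng = np.random.RandomState(0)
vocab = ["boy", "girl", "aunt", "uncle", "king", "queen"]
w2i = {w: i for i, w in enumerate(vocab)}
emb = rng.normal(size=(len(vocab), 8)).astype("float32")  # placeholder embeddings

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def analogy_topk(a, b, c, k=4):
    # query = emb[a] - emb[b] + emb[c], then cosine similarity against all words
    query = emb[w2i[a]] - emb[w2i[b]] + emb[w2i[c]]
    query = query / np.linalg.norm(query)
    scores = l2_normalize(emb).dot(query)
    order = np.argsort(-scores)
    # drop the three query words, as the accuracy path does via exclude_lists
    return [(vocab[i], float(scores[i])) for i in order if vocab[i] not in (a, b, c)][:k]

print(analogy_topk("boy", "girl", "aunt"))  # with trained vectors, "uncle" should rank first
```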
correct_num = 0 - cosine_similarity_matrix = np.dot(test_cases, x.T) - results = topKs(accual_rank, cosine_similarity_matrix, exclude_lists, - args.test_acc) - for i in range(len(test_labels)): - logger.info("Test result for {}".format(test_case_desc[i])) - result = results[i] - for j in range(accual_rank): - if result[j][1] == test_labels[ - i]: # if the nearest word is what we want - correct_num += 1 - logger.info("{} nearest is {}, rate is {}".format(j, id_to_word[ - result[j][1]], result[j][0])) - logger.info("Test acc is: {}, there are {} / {}".format(correct_num / len( - test_labels), correct_num, len(test_labels))) - - -def topK(k, cosine_similarity_list, exclude_list, is_acc=False): - if k == 1 and is_acc: # accelerate acc calculate - max = cosine_similarity_list[0] - id = 0 - for i in range(len(cosine_similarity_list)): - if cosine_similarity_list[i] >= max and (i not in exclude_list): - max = cosine_similarity_list[i] - id = i - else: - pass - return [[max, id]] - else: - result = list() - result_index = np.argpartition(cosine_similarity_list, -k)[-k:] - for index in result_index: - result.append([cosine_similarity_list[index], index]) - result.sort(reverse=True) - return result - - -def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False): - results = list() - result_queues = list() - correct_num = 0 - - for i in range(cosine_similarity_matrix.shape[0]): - tmp_pq = None - if is_acc: - tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[i], - is_acc) - else: - tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0], - is_acc) - result_queues.append(tmp_pq) - return result_queues - - -def infer_during_train(args): - model_file_list = list() - exe = fluid.Executor(fluid.CPUPlace()) - Scope = fluid.Scope() - inference_prog() - solved_new = True - while True: - time.sleep(60) - current_list = os.listdir(args.model_output_dir) - if set(model_file_list) == set(current_list): - if solved_new: - solved_new = False - logger.info("No New models created") - pass - else: - solved_new = True - increment_models = list() - for f in current_list: - if f not in model_file_list: - increment_models.append(f) - logger.info("increment_models is : {}".format(increment_models)) - for model in increment_models: - model_dir = args.model_output_dir + "/" + model - if os.path.exists(model_dir + "/_success"): - logger.info("using models from " + model_dir) - with fluid.scope_guard(Scope): - fluid.io.load_persistables( - executor=exe, dirname=model_dir + "/") - inference_test(Scope, model_dir, args) - model_file_list = current_list - - -def infer_once(args): - # check models file has already been finished - if os.path.exists(args.model_output_dir + "/_success"): - logger.info("using models from " + args.model_output_dir) - exe = fluid.Executor(fluid.CPUPlace()) - Scope = fluid.Scope() - inference_prog() - with fluid.scope_guard(Scope): - fluid.io.load_persistables( - executor=exe, dirname=args.model_output_dir + "/") - inference_test(Scope, args.model_output_dir, args) - else: - logger.info("Wrong Directory or save model failed!") - - -if __name__ == '__main__': + '--print_step', type=int, default='500000', help='print step') + parser.add_argument( + '--start_index', type=int, default='0', help='start index') + parser.add_argument( + '--start_batch', type=int, default='1', help='start index') + parser.add_argument( + '--end_batch', type=int, default='13', help='start index') + parser.add_argument( + '--last_index', type=int, default='100', help='last index') + parser.add_argument( + 
'--model_dir', type=str, default='model', help='model dir') + parser.add_argument( + '--use_cuda', type=int, default='0', help='whether use cuda') + parser.add_argument( + '--batch_size', type=int, default='5', help='batch_size') + parser.add_argument('--emb_size', type=int, default='64', help='batch_size') + args = parser.parse_args() + return args + + +def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w): + """ inference function """ + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + emb_size = args.emb_size + batch_size = args.batch_size + with fluid.scope_guard(fluid.core.Scope()): + main_program = fluid.Program() + with fluid.program_guard(main_program): + values, pred = net.infer_network(vocab_size, emb_size) + for epoch in range(start_index, last_index + 1): + copy_program = main_program.clone() + model_path = model_dir + "/pass-" + str(epoch) + fluid.io.load_params( + executor=exe, dirname=model_path, main_program=copy_program) + accum_num = 0 + accum_num_sum = 0.0 + t0 = time.time() + step_id = 0 + for data in test_reader(): + step_id += 1 + b_size = len([dat[0] for dat in data]) + wa = np.array( + [dat[0] for dat in data]).astype("int64").reshape( + b_size, 1) + wb = np.array( + [dat[1] for dat in data]).astype("int64").reshape( + b_size, 1) + wc = np.array( + [dat[2] for dat in data]).astype("int64").reshape( + b_size, 1) + + label = [dat[3] for dat in data] + input_word = [dat[4] for dat in data] + para = exe.run( + copy_program, + feed={ + "analogy_a": wa, + "analogy_b": wb, + "analogy_c": wc, + "all_label": + np.arange(vocab_size).reshape(vocab_size, 1), + }, + fetch_list=[pred.name, values], + return_numpy=False) + pre = np.array(para[0]) + val = np.array(para[1]) + for ii in range(len(label)): + top4 = pre[ii] + accum_num_sum += 1 + for idx in top4: + if int(idx) in input_word[ii]: + continue + if int(idx) == int(label[ii][0]): + accum_num += 1 + break + if step_id % 1 == 0: + print("step:%d %d " % (step_id, accum_num)) + + print("epoch:%d \t acc:%.3f " % + (epoch, 1.0 * accum_num / accum_num_sum)) + + +def infer_step(args, vocab_size, test_reader, use_cuda, i2w): + """ inference function """ + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + emb_size = args.emb_size + batch_size = args.batch_size + with fluid.scope_guard(fluid.core.Scope()): + main_program = fluid.Program() + with fluid.program_guard(main_program): + values, pred = net.infer_network(vocab_size, emb_size) + for epoch in range(start_index, last_index + 1): + for batchid in range(args.start_batch, args.end_batch): + copy_program = main_program.clone() + model_path = model_dir + "/pass-" + str(epoch) + ( + '/batch-' + str(batchid * args.print_step)) + fluid.io.load_params( + executor=exe, + dirname=model_path, + main_program=copy_program) + accum_num = 0 + accum_num_sum = 0.0 + t0 = time.time() + step_id = 0 + for data in test_reader(): + step_id += 1 + b_size = len([dat[0] for dat in data]) + wa = np.array( + [dat[0] for dat in data]).astype("int64").reshape( + b_size, 1) + wb = np.array( + [dat[1] for dat in data]).astype("int64").reshape( + b_size, 1) + wc = np.array( + [dat[2] for dat in data]).astype("int64").reshape( + b_size, 1) + + label = [dat[3] for dat in data] + input_word = [dat[4] for dat in data] + para = exe.run( + copy_program, + feed={ + "analogy_a": wa, + "analogy_b": wb, + "analogy_c": wc, + "all_label": + np.arange(vocab_size).reshape(vocab_size, 1), + }, + fetch_list=[pred.name, values], + 
return_numpy=False) + pre = np.array(para[0]) + val = np.array(para[1]) + for ii in range(len(label)): + top4 = pre[ii] + accum_num_sum += 1 + for idx in top4: + if int(idx) in input_word[ii]: + continue + if int(idx) == int(label[ii][0]): + accum_num += 1 + break + if step_id % 1 == 0: + print("step:%d %d " % (step_id, accum_num)) + print("epoch:%d \t acc:%.3f " % + (epoch, 1.0 * accum_num / accum_num_sum)) + t1 = time.time() + + +if __name__ == "__main__": args = parse_args() - # while setting infer_once please specify the dir to models file with --model_output_dir - if args.infer_once: - infer_once(args) - elif args.infer_during_train: - infer_during_train(args) + start_index = args.start_index + last_index = args.last_index + test_dir = args.test_dir + model_dir = args.model_dir + batch_size = args.batch_size + dict_path = args.dict_path + use_cuda = True if args.use_cuda else False + print("start index: ", start_index, " last_index:", last_index) + vocab_size, test_reader, id2word = utils.prepare_data( + test_dir, dict_path, batch_size=batch_size) + print("vocab_size:", vocab_size) + if args.infer_step: + infer_step( + args, + vocab_size, + test_reader=test_reader, + use_cuda=use_cuda, + i2w=id2word) else: - pass + infer_epoch( + args, + vocab_size, + test_reader=test_reader, + use_cuda=use_cuda, + i2w=id2word) diff --git a/fluid/PaddleRec/word2vec/net.py b/fluid/PaddleRec/word2vec/net.py new file mode 100644 index 00000000..b900c62a --- /dev/null +++ b/fluid/PaddleRec/word2vec/net.py @@ -0,0 +1,144 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
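Before the network definition itself, it may help to spell out the objective that `skip_gram_word2vec` assembles below out of embedding lookups, a matmul against the sampled negatives, and `sigmoid_cross_entropy_with_logits`: the standard "one true context word versus `neg_num` sampled negatives" loss. A numpy sketch of that per-example cost, with random toy values standing in for the learned parameters:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# One training example: embedding size 4, five sampled negatives (all values random).
rng = np.random.RandomState(0)
input_emb = rng.normal(scale=0.1, size=4)   # center-word vector ("emb")
true_w = rng.normal(scale=0.1, size=4)      # output weights of the true context word ("emb_w")
true_b = 0.0                                # its bias ("emb_b")
neg_w = rng.normal(scale=0.1, size=(5, 4))  # output weights of the 5 negatives
neg_b = np.zeros(5)

true_logit = input_emb.dot(true_w) + true_b  # reduce_sum(elementwise_mul(...)) + bias
neg_logits = neg_w.dot(input_emb) + neg_b    # the matmul path: one logit per negative

# Label 1 for the true pair, 0 for every negative; the per-example cost is the
# sum of the sigmoid cross-entropies (true_xent plus the summed neg_xent).
loss = -np.log(sigmoid(true_logit)) - np.log(1.0 - sigmoid(neg_logits)).sum()
print(float(loss))
```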
+""" +neural network for word2vec +""" +from __future__ import print_function +import math +import numpy as np +import paddle.fluid as fluid + + +def skip_gram_word2vec(dict_size, + word_frequencys, + embedding_size, + is_sparse=False): + + datas = [] + neg_num = 5 + input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64') + true_word = fluid.layers.data(name='true_label', shape=[1], dtype='int64') + neg_word = fluid.layers.data( + name="neg_label", shape=[neg_num, 5], dtype='int64') + + datas.append(input_word) + datas.append(true_word) + datas.append(neg_word) + + py_reader = fluid.layers.create_py_reader_by_data( + capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True) + + words = fluid.layers.read_file(py_reader) + init_width = 0.5 / embedding_size + input_emb = fluid.layers.embedding( + input=words[0], + is_sparse=is_sparse, + size=[dict_size, embedding_size], + param_attr=fluid.ParamAttr( + name='emb', + initializer=fluid.initializer.Uniform(-init_width, init_width))) + + true_emb_w = fluid.layers.embedding( + input=words[1], + is_sparse=is_sparse, + size=[dict_size, embedding_size], + param_attr=fluid.ParamAttr( + name='emb_w', initializer=fluid.initializer.Constant(value=0.0))) + + true_emb_b = fluid.layers.embedding( + input=words[1], + is_sparse=is_sparse, + size=[dict_size, 1], + param_attr=fluid.ParamAttr( + name='emb_b', initializer=fluid.initializer.Constant(value=0.0))) + neg_word_reshape = fluid.layers.reshape(words[2], shape=[-1, 1]) + neg_word_reshape.stop_gradient = True + + neg_emb_w = fluid.layers.embedding( + input=neg_word_reshape, + is_sparse=is_sparse, + size=[dict_size, embedding_size], + param_attr=fluid.ParamAttr( + name='emb_w', learning_rate=10.0)) + # param_attr='emb_w') + + neg_emb_w_re = fluid.layers.reshape( + neg_emb_w, shape=[-1, neg_num, embedding_size]) + neg_emb_b = fluid.layers.embedding( + input=neg_word_reshape, + is_sparse=is_sparse, + size=[dict_size, 1], + param_attr=fluid.ParamAttr( + name='emb_b', learning_rate=10.0)) + #param_attr='emb_b') + + neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num]) + true_logits = fluid.layers.elementwise_add( + fluid.layers.reduce_sum( + fluid.layers.elementwise_mul(input_emb, true_emb_w), + dim=1, + keep_dim=True), + true_emb_b) + input_emb_re = fluid.layers.reshape( + input_emb, shape=[-1, 1, embedding_size]) + neg_matmul = fluid.layers.matmul( + input_emb_re, neg_emb_w_re, transpose_y=True) + neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num]) + #neg_matmul_reshape = fluid.layers.reshape(neg_matmul, shape=[-1, ] + neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec) + #nce loss + + label_ones = fluid.layers.fill_constant_batch_size_like( + true_logits, shape=[-1, 1], value=1.0, dtype='float32') + label_zeros = fluid.layers.fill_constant_batch_size_like( + true_logits, shape=[-1, neg_num], value=0.0, dtype='float32') + + true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits, + label_ones) + neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits, + label_zeros) + + cost = fluid.layers.elementwise_add( + fluid.layers.reduce_sum( + true_xent, dim=1), + fluid.layers.reduce_sum( + neg_xent, dim=1)) + avg_cost = fluid.layers.reduce_mean(cost) + return avg_cost, py_reader + + +def infer_network(vocab_size, emb_size): + analogy_a = fluid.layers.data(name="analogy_a", shape=[1], dtype='int64') + analogy_b = fluid.layers.data(name="analogy_b", shape=[1], dtype='int64') + analogy_c = 
fluid.layers.data(name="analogy_c", shape=[1], dtype='int64') + all_label = fluid.layers.data( + name="all_label", + shape=[vocab_size, 1], + dtype='int64', + append_batch_size=False) + emb_all_label = fluid.layers.embedding( + input=all_label, size=[vocab_size, emb_size], param_attr="emb") + + emb_a = fluid.layers.embedding( + input=analogy_a, size=[vocab_size, emb_size], param_attr="emb") + emb_b = fluid.layers.embedding( + input=analogy_b, size=[vocab_size, emb_size], param_attr="emb") + emb_c = fluid.layers.embedding( + input=analogy_c, size=[vocab_size, emb_size], param_attr="emb") + target = fluid.layers.elementwise_add( + fluid.layers.elementwise_sub(emb_b, emb_a), emb_c) + emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1) + dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True) + values, pred_idx = fluid.layers.topk(input=dist, k=4) + return values, pred_idx diff --git a/fluid/PaddleRec/word2vec/network_conf.py b/fluid/PaddleRec/word2vec/network_conf.py deleted file mode 100644 index 16178c33..00000000 --- a/fluid/PaddleRec/word2vec/network_conf.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -neural network for word2vec -""" - -from __future__ import print_function - -import math -import numpy as np - -import paddle.fluid as fluid - - -def skip_gram_word2vec(dict_size, - word_frequencys, - embedding_size, - max_code_length=None, - with_hsigmoid=False, - with_nce=True, - is_sparse=False): - def nce_layer(input, label, embedding_size, num_total_classes, - num_neg_samples, sampler, word_frequencys, sample_weight): - - w_param_name = "nce_w" - b_param_name = "nce_b" - w_param = fluid.default_main_program().global_block().create_parameter( - shape=[num_total_classes, embedding_size], - dtype='float32', - name=w_param_name) - b_param = fluid.default_main_program().global_block().create_parameter( - shape=[num_total_classes, 1], dtype='float32', name=b_param_name) - - cost = fluid.layers.nce(input=input, - label=label, - num_total_classes=num_total_classes, - sampler=sampler, - custom_dist=word_frequencys, - sample_weight=sample_weight, - param_attr=fluid.ParamAttr(name=w_param_name), - bias_attr=fluid.ParamAttr(name=b_param_name), - num_neg_samples=num_neg_samples, - is_sparse=is_sparse) - - return cost - - def hsigmoid_layer(input, label, path_table, path_code, non_leaf_num, - is_sparse): - if non_leaf_num is None: - non_leaf_num = dict_size - - cost = fluid.layers.hsigmoid( - input=input, - label=label, - num_classes=non_leaf_num, - path_table=path_table, - path_code=path_code, - is_custom=True, - is_sparse=is_sparse) - - return cost - - datas = [] - - input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64') - predict_word = fluid.layers.data( - name='predict_word', shape=[1], dtype='int64') - datas.append(input_word) - datas.append(predict_word) - - if with_hsigmoid: - path_table = fluid.layers.data( - name='path_table', - 
shape=[max_code_length if max_code_length else 40], - dtype='int64') - path_code = fluid.layers.data( - name='path_code', - shape=[max_code_length if max_code_length else 40], - dtype='int64') - datas.append(path_table) - datas.append(path_code) - - py_reader = fluid.layers.create_py_reader_by_data( - capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True) - - words = fluid.layers.read_file(py_reader) - target_emb = fluid.layers.embedding( - input=words[0], - is_sparse=is_sparse, - size=[dict_size, embedding_size], - param_attr=fluid.ParamAttr( - name='embeding', - initializer=fluid.initializer.Normal(scale=1 / - math.sqrt(dict_size)))) - context_emb = fluid.layers.embedding( - input=words[1], - is_sparse=is_sparse, - size=[dict_size, embedding_size], - param_attr=fluid.ParamAttr( - name='embeding', - initializer=fluid.initializer.Normal(scale=1 / - math.sqrt(dict_size)))) - cost, cost_nce, cost_hs = None, None, None - - if with_nce: - cost_nce = nce_layer(target_emb, words[1], embedding_size, dict_size, 5, - "uniform", word_frequencys, None) - cost = cost_nce - if with_hsigmoid: - cost_hs = hsigmoid_layer(context_emb, words[0], words[2], words[3], - dict_size, is_sparse) - cost = cost_hs - if with_nce and with_hsigmoid: - cost = fluid.layers.elementwise_add(cost_nce, cost_hs) - - avg_cost = fluid.layers.reduce_mean(cost) - - return avg_cost, py_reader diff --git a/fluid/PaddleRec/word2vec/preprocess.py b/fluid/PaddleRec/word2vec/preprocess.py index f13d3354..5af174e7 100644 --- a/fluid/PaddleRec/word2vec/preprocess.py +++ b/fluid/PaddleRec/word2vec/preprocess.py @@ -1,72 +1,55 @@ # -*- coding: utf-8 -* - +import os +import random import re import six import argparse import io - +import math prog = re.compile("[^a-z ]", flags=0) -word_count = dict() def parse_args(): parser = argparse.ArgumentParser( description="Paddle Fluid word2 vector preprocess") parser.add_argument( - '--data_path', - type=str, - required=True, - help="The path of training dataset") + '--build_dict_corpus_dir', type=str, help="The dir of corpus") + parser.add_argument( + '--input_corpus_dir', type=str, help="The dir of input corpus") + parser.add_argument( + '--output_corpus_dir', type=str, help="The dir of output corpus") parser.add_argument( '--dict_path', type=str, default='./dict', - help="The path of generated dict") + help="The path of dictionary ") parser.add_argument( - '--freq', + '--min_count', type=int, default=5, - help="If the word count is less then freq, it will be removed from dict") - + help="If the word count is less then min_count, it will be removed from dict" + ) + parser.add_argument( + '--downsample', + type=float, + default=0.001, + help="filter word by downsample") parser.add_argument( - '--with_other_dict', + '--filter_corpus', action='store_true', - required=False, default=False, - help='Using third party provided dict , (default: False)') - + help='Filter corpus') parser.add_argument( - '--other_dict_path', - type=str, - default='', - help='The path for third party provided dict (default: ' - ')') - + '--build_dict', + action='store_true', + default=False, + help='Build dict from corpus') return parser.parse_args() def text_strip(text): - return prog.sub("", text) - - -# users can self-define their own strip rules by modifing this method -def strip_lines(line, vocab=word_count): - return _replace_oov(vocab, native_to_unicode(line)) - - -# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py -def 
_replace_oov(original_vocab, line): - """Replace out-of-vocab words with "". - This maintains compatibility with published results. - Args: - original_vocab: a set of strings (The standard vocabulary for the dataset) - line: a unicode string - a space-delimited sequence of words. - Returns: - a unicode string - a space-delimited sequence of words. - """ - return u" ".join([ - word if word in original_vocab else u"" for word in line.split() - ]) + #English Preprocess Rule + return prog.sub("", text.lower()) # Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py @@ -98,147 +81,107 @@ def _to_unicode(s, ignore_errors=False): return s.decode("utf-8", errors=error_mode) -def build_Huffman(word_count, max_code_length): - - MAX_CODE_LENGTH = max_code_length - sorted_by_freq = sorted(word_count.items(), key=lambda x: x[1]) - count = list() - vocab_size = len(word_count) - parent = [-1] * 2 * vocab_size - code = [-1] * MAX_CODE_LENGTH - point = [-1] * MAX_CODE_LENGTH - binary = [-1] * 2 * vocab_size - word_code_len = dict() - word_code = dict() - word_point = dict() - i = 0 - for a in range(vocab_size): - count.append(word_count[sorted_by_freq[a][0]]) - - for a in range(vocab_size): - word_point[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH - word_code[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH - - for k in range(vocab_size): - count.append(1e15) - - pos1 = vocab_size - 1 - pos2 = vocab_size - min1i = 0 - min2i = 0 - b = 0 - - for r in range(vocab_size): - if pos1 >= 0: - if count[pos1] < count[pos2]: - min1i = pos1 - pos1 = pos1 - 1 - else: - min1i = pos2 - pos2 = pos2 + 1 - else: - min1i = pos2 - pos2 = pos2 + 1 - if pos1 >= 0: - if count[pos1] < count[pos2]: - min2i = pos1 - pos1 = pos1 - 1 - else: - min2i = pos2 - pos2 = pos2 + 1 - else: - min2i = pos2 - pos2 = pos2 + 1 - - count[vocab_size + r] = count[min1i] + count[min2i] - - #record the parent of left and right child - parent[min1i] = vocab_size + r - parent[min2i] = vocab_size + r - binary[min1i] = 0 #left branch has code 0 - binary[min2i] = 1 #right branch has code 1 - - for a in range(vocab_size): - b = a - i = 0 - while True: - code[i] = binary[b] - point[i] = b - i = i + 1 - b = parent[b] - if b == vocab_size * 2 - 2: - break - - word_code_len[sorted_by_freq[a][0]] = i - word_point[sorted_by_freq[a][0]][0] = vocab_size - 2 - - for k in range(i): - word_code[sorted_by_freq[a][0]][i - k - 1] = code[k] - - # only non-leaf nodes will be count in - if point[k] - vocab_size >= 0: - word_point[sorted_by_freq[a][0]][i - k] = point[k] - vocab_size - - return word_point, word_code, word_code_len - - -def preprocess(args): +def filter_corpus(args): + """ + filter corpus and convert id. 
+ """ + word_count = dict() + word_to_id_ = dict() + word_all_count = 0 + id_counts = [] + word_id = 0 + #read dict + with io.open(args.dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, count = line.split()[0], int(line.split()[1]) + word_count[word] = count + word_to_id_[word] = word_id + word_id += 1 + id_counts.append(count) + word_all_count += count + + #filter corpus and convert id + if not os.path.exists(args.output_corpus_dir): + os.makedirs(args.output_corpus_dir) + for file in os.listdir(args.input_corpus_dir): + with io.open(args.output_corpus_dir + '/convert_' + file, "w") as wf: + with io.open( + args.input_corpus_dir + '/' + file, encoding='utf-8') as rf: + print(args.input_corpus_dir + '/' + file) + for line in rf: + signal = False + line = text_strip(line) + words = line.split() + for item in words: + if item in word_count: + idx = word_to_id_[item] + else: + idx = word_to_id_[native_to_unicode('')] + count_w = id_counts[idx] + corpus_size = word_all_count + keep_prob = ( + math.sqrt(count_w / + (args.downsample * corpus_size)) + 1 + ) * (args.downsample * corpus_size) / count_w + r_value = random.random() + if r_value > keep_prob: + continue + wf.write(_to_unicode(str(idx) + " ")) + signal = True + if signal: + wf.write(_to_unicode("\n")) + + +def build_dict(args): """ proprocess the data, generate dictionary and save into dict_path. - :param data_path: the input data path. + :param corpus_dir: the input data dir. :param dict_path: the generated dict path. the data in dict is "word count" - :param freq: + :param min_count: :return: """ # word to count - if args.with_other_dict: - with io.open(args.other_dict_path, 'r', encoding='utf-8') as f: - for line in f: - word_count[native_to_unicode(line.strip())] = 1 + word_count = dict() - for i in range(1, 100): + for file in os.listdir(args.build_dict_corpus_dir): with io.open( - args.data_path + "/news.en-000{:0>2d}-of-00100".format(i), - encoding='utf-8') as f: + args.build_dict_corpus_dir + "/" + file, encoding='utf-8') as f: + print("build dict : ", args.build_dict_corpus_dir + "/" + file) for line in f: - if args.with_other_dict: - line = strip_lines(line) - words = line.split() - for item in words: - if item in word_count: - word_count[item] = word_count[item] + 1 - else: - word_count[native_to_unicode('')] += 1 - else: - line = text_strip(line) - words = line.split() - for item in words: - if item in word_count: - word_count[item] = word_count[item] + 1 - else: - word_count[item] = 1 + line = text_strip(line) + words = line.split() + for item in words: + if item in word_count: + word_count[item] = word_count[item] + 1 + else: + word_count[item] = 1 + item_to_remove = [] for item in word_count: - if word_count[item] <= args.freq: + if word_count[item] <= args.min_count: item_to_remove.append(item) + + unk_sum = 0 for item in item_to_remove: + unk_sum += word_count[item] del word_count[item] - - path_table, path_code, word_code_len = build_Huffman(word_count, 40) + #sort by count + word_count[native_to_unicode('')] = unk_sum + word_count = sorted( + word_count.items(), key=lambda word_count: -word_count[1]) with io.open(args.dict_path, 'w+', encoding='utf-8') as f: - for k, v in word_count.items(): + for k, v in word_count: f.write(k + " " + str(v) + '\n') - with io.open(args.dict_path + "_ptable", 'w+', encoding='utf-8') as f2: - for pk, pv in path_table.items(): - f2.write(pk + '\t' + ' '.join((str(x) for x in pv)) + '\n') - - with io.open(args.dict_path + "_pcode", 'w+', encoding='utf-8') as f3: - for pck, 
pcv in path_code.items(): - f3.write(pck + '\t' + ' '.join((str(x) for x in pcv)) + '\n') - if __name__ == "__main__": - preprocess(parse_args()) + args = parse_args() + if args.build_dict: + build_dict(args) + elif args.filter_corpus: + filter_corpus(args) + else: + print( + "error command line, please choose --build_dict or --filter_corpus") diff --git a/fluid/PaddleRec/word2vec/reader.py b/fluid/PaddleRec/word2vec/reader.py index df479a4b..14a29caa 100644 --- a/fluid/PaddleRec/word2vec/reader.py +++ b/fluid/PaddleRec/word2vec/reader.py @@ -3,6 +3,8 @@ import numpy as np import preprocess import logging +import math +import random import io logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') @@ -39,17 +41,14 @@ class Word2VecReader(object): self.window_size_ = window_size self.data_path_ = data_path self.filelist = filelist - self.num_non_leaf = 0 self.word_to_id_ = dict() self.id_to_word = dict() self.word_count = dict() - self.word_to_path = dict() - self.word_to_code = dict() self.trainer_id = trainer_id self.trainer_num = trainer_num word_all_count = 0 - word_counts = [] + id_counts = [] word_id = 0 with io.open(dict_path, 'r', encoding='utf-8') as f: @@ -59,39 +58,31 @@ class Word2VecReader(object): self.word_to_id_[word] = word_id self.id_to_word[word_id] = word #build id to word dict word_id += 1 - word_counts.append(count) + id_counts.append(count) word_all_count += count + self.word_all_count = word_all_count + self.corpus_size_ = word_all_count + self.dict_size = len(self.word_to_id_) + self.id_counts_ = id_counts + #write word2id file + print("write word2id file to : " + dict_path + "_word_to_id_") with io.open(dict_path + "_word_to_id_", 'w+', encoding='utf-8') as f6: for k, v in self.word_to_id_.items(): f6.write(k + " " + str(v) + '\n') - self.dict_size = len(self.word_to_id_) - self.word_frequencys = [ - float(count) / word_all_count for count in word_counts + print("corpus_size:", self.corpus_size_) + self.id_frequencys = [ + float(count) / word_all_count for count in self.id_counts_ ] - print("dict_size = " + str(self.dict_size) + " word_all_count = " + str( - word_all_count)) - - with io.open(dict_path + "_ptable", 'r', encoding='utf-8') as f2: - for line in f2: - self.word_to_path[line.split('\t')[0]] = np.fromstring( - line.split('\t')[1], dtype=int, sep=' ') - self.num_non_leaf = np.fromstring( - line.split('\t')[1], dtype=int, sep=' ')[0] - print("word_ptable dict_size = " + str(len(self.word_to_path))) - - with io.open(dict_path + "_pcode", 'r', encoding='utf-8') as f3: - for line in f3: - self.word_to_code[line.split('\t')[0]] = np.fromstring( - line.split('\t')[1], dtype=int, sep=' ') - print("word_pcode dict_size = " + str(len(self.word_to_code))) + print("dict_size = " + str( + self.dict_size)) + " word_all_count = " + str(word_all_count) + self.random_generator = NumpyRandomInt(1, self.window_size_ + 1) def get_context_words(self, words, idx): """ Get the context word list of target word. 
- words: the words of the current line idx: input word index window_size: window size @@ -102,11 +93,10 @@ class Word2VecReader(object): start_point = 0 end_point = idx + target_window targets = words[start_point:idx] + words[idx + 1:end_point + 1] + return targets - return set(targets) - - def train(self, with_hs, with_other_dict): - def _reader(): + def train(self): + def nce_reader(): for file in self.filelist: with io.open( self.data_path_ + "/" + file, 'r', @@ -116,80 +106,12 @@ class Word2VecReader(object): count = 1 for line in f: if self.trainer_id == count % self.trainer_num: - if with_other_dict: - line = preprocess.strip_lines(line, - self.word_count) - else: - line = preprocess.text_strip(line) - word_ids = [ - self.word_to_id_[word] for word in line.split() - if word in self.word_to_id_ - ] + word_ids = [int(w) for w in line.split()] for idx, target_id in enumerate(word_ids): context_word_ids = self.get_context_words( word_ids, idx) for context_id in context_word_ids: yield [target_id], [context_id] - else: - pass - count += 1 - - def _reader_hs(): - for file in self.filelist: - with io.open( - self.data_path_ + "/" + file, 'r', - encoding='utf-8') as f: - logger.info("running data in {}".format(self.data_path_ + - "/" + file)) - count = 1 - for line in f: - if self.trainer_id == count % self.trainer_num: - if with_other_dict: - line = preprocess.strip_lines(line, - self.word_count) - else: - line = preprocess.text_strip(line) - word_ids = [ - self.word_to_id_[word] for word in line.split() - if word in self.word_to_id_ - ] - for idx, target_id in enumerate(word_ids): - context_word_ids = self.get_context_words( - word_ids, idx) - for context_id in context_word_ids: - yield [target_id], [context_id], [ - self.word_to_path[self.id_to_word[ - target_id]] - ], [ - self.word_to_code[self.id_to_word[ - target_id]] - ] - else: - pass count += 1 - if not with_hs: - return _reader - else: - return _reader_hs - - -if __name__ == "__main__": - window_size = 5 - - reader = Word2VecReader( - "./data/1-billion_dict", - "./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/", - ["news.en-00001-of-00100"], 0, 1) - - i = 0 - # print(reader.train(True)) - for x, y, z, f in reader.train(True)(): - print("x: " + str(x)) - print("y: " + str(y)) - print("path: " + str(z)) - print("code: " + str(f)) - print("\n") - if i == 10: - exit(0) - i += 1 + return nce_reader diff --git a/fluid/PaddleRec/word2vec/train.py b/fluid/PaddleRec/word2vec/train.py index 2c8512e4..99bae695 100644 --- a/fluid/PaddleRec/word2vec/train.py +++ b/fluid/PaddleRec/word2vec/train.py @@ -3,19 +3,14 @@ import argparse import logging import os import time - +import math +import random import numpy as np - -# disable gpu training for this example -os.environ["CUDA_VISIBLE_DEVICES"] = "" - import paddle import paddle.fluid as fluid -from paddle.fluid.executor import global_scope import six import reader -from network_conf import skip_gram_word2vec -from infer import inference_test +from net import skip_gram_word2vec logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger("fluid") @@ -26,25 +21,35 @@ def parse_args(): parser = argparse.ArgumentParser( description="PaddlePaddle Word2vec example") parser.add_argument( - '--train_data_path', + '--train_data_dir', type=str, - default='./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled', + default='./data/text', help="The path of taining dataset") + 
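The reader above sets up `NumpyRandomInt(1, window_size + 1)` and draws a fresh window size for every target word, which is the usual word2vec way of weighting nearby context words more heavily than distant ones. A short sketch of how one preprocessed line of ids expands into (target, context) training pairs, assuming a maximum window of 5:

```python
import numpy as np

rng = np.random.RandomState(0)

def context_words(words, idx, window_size=5):
    # Same idea as Word2VecReader.get_context_words: pick a window in
    # [1, window_size] for this target, then take that many ids on each side.
    target_window = rng.randint(1, window_size + 1)
    start = max(idx - target_window, 0)
    return words[start:idx] + words[idx + 1:idx + target_window + 1]

line_ids = [12, 7, 93, 5, 41, 8]   # one converted line: word ids after preprocessing
pairs = [(target, ctx)
         for idx, target in enumerate(line_ids)
         for ctx in context_words(line_ids, idx)]
print(pairs[:6])   # each (target, context) pair becomes one sample: yield [target], [ctx]
```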
parser.add_argument( + '--base_lr', + type=float, + default=0.01, + help="The number of learing rate (default: 0.01)") + parser.add_argument( + '--save_step', + type=int, + default=500000, + help="The number of step to save (default: 500000)") + parser.add_argument( + '--print_batch', + type=int, + default=10, + help="The number of print_batch (default: 10)") parser.add_argument( '--dict_path', type=str, default='./data/1-billion_dict', help="The path of data dict") - parser.add_argument( - '--test_data_path', - type=str, - default='./data/text8', - help="The path of testing dataset") parser.add_argument( '--batch_size', type=int, - default=1000, - help="The size of mini-batch (default:100)") + default=500, + help="The size of mini-batch (default:500)") parser.add_argument( '--num_passes', type=int, @@ -55,90 +60,31 @@ def parse_args(): type=str, default='models', help='The path for model to store (default: models)') + parser.add_argument('--nce_num', type=int, default=5, help='nce_num') parser.add_argument( '--embedding_size', type=int, default=64, help='sparse feature hashing space for index processing') - - parser.add_argument( - '--with_hs', - action='store_true', - required=False, - default=False, - help='using hierarchical sigmoid, (default: False)') - - parser.add_argument( - '--with_nce', - action='store_true', - required=False, - default=False, - help='using negtive sampling, (default: True)') - - parser.add_argument( - '--max_code_length', - type=int, - default=40, - help='max code length used by hierarchical sigmoid, (default: 40)') - parser.add_argument( '--is_sparse', action='store_true', required=False, default=False, help='embedding and nce will use sparse or not, (default: False)') - - parser.add_argument( - '--with_Adam', - action='store_true', - required=False, - default=False, - help='Using Adam as optimizer or not, (default: False)') - - parser.add_argument( - '--is_local', - action='store_true', - required=False, - default=False, - help='Local train or not, (default: False)') - parser.add_argument( '--with_speed', action='store_true', required=False, default=False, help='print speed or not , (default: False)') - - parser.add_argument( - '--with_infer_test', - action='store_true', - required=False, - default=False, - help='Do inference every 100 batches , (default: False)') - - parser.add_argument( - '--with_other_dict', - action='store_true', - required=False, - default=False, - help='if use other dict , (default: False)') - - parser.add_argument( - '--rank_num', - type=int, - default=4, - help="find rank_num-nearest result for test (default: 4)") - return parser.parse_args() -def convert_python_to_tensor(batch_size, sample_reader, is_hs): +def convert_python_to_tensor(weight, batch_size, sample_reader): def __reader__(): - result = None - if is_hs: - result = [[], [], [], []] - else: - result = [[], []] + cs = np.array(weight).cumsum() + result = [[], []] for sample in sample_reader(): for i, fea in enumerate(sample): result[i].append(fea) @@ -152,27 +98,27 @@ def convert_python_to_tensor(batch_size, sample_reader, is_hs): elif len(dat.shape) == 1: dat = dat.reshape((-1, 1)) t.set(dat, fluid.CPUPlace()) - tensor_result.append(t) + tt = fluid.Tensor() + neg_array = cs.searchsorted(np.random.sample(args.nce_num)) + neg_array = np.tile(neg_array, batch_size) + tt.set( + neg_array.reshape((batch_size, args.nce_num)), + fluid.CPUPlace()) + tensor_result.append(tt) yield tensor_result - if is_hs: - result = [[], [], [], []] - else: - result = [[], []] + result = [[], []] 
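`convert_python_to_tensor` above draws the shared negative-sample ids for each batch by inverse-transform sampling: it takes the cumulative sum of the word-probability vector passed in as `weight` and `searchsorted`s uniform random numbers into it, then tiles the result across the batch. A tiny self-contained sketch of that idiom with toy probabilities (in the training script the weights come from the reader's `id_frequencys`):

```python
import numpy as np

np.random.seed(0)

# Toy unigram probabilities over a 6-word vocabulary (they sum to 1).
weight = np.array([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])
cs = weight.cumsum()                 # [0.40, 0.65, 0.80, 0.90, 0.96, 1.00]

nce_num = 5
neg_ids = cs.searchsorted(np.random.sample(nce_num))
print(neg_ids)                       # 5 word ids, drawn with probability weight[id]

# The same negatives are reused for the whole batch, mirroring
# np.tile(neg_array, batch_size).reshape((batch_size, nce_num)) above.
batch_size = 4
neg_for_batch = np.tile(neg_ids, batch_size).reshape(batch_size, nce_num)
print(neg_for_batch.shape)           # (4, 5)
```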
return __reader__ -def train_loop(args, train_program, reader, py_reader, loss, trainer_id): +def train_loop(args, train_program, reader, py_reader, loss, trainer_id, + weight): py_reader.decorate_tensor_provider( - convert_python_to_tensor(args.batch_size, - reader.train((args.with_hs or ( - not args.with_nce)), args.with_other_dict), - (args.with_hs or (not args.with_nce)))) + convert_python_to_tensor(weight, args.batch_size, reader.train())) place = fluid.CPUPlace() - exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) @@ -193,50 +139,38 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): build_strategy=build_strategy, exec_strategy=exec_strategy) - profile_state = "CPU" - profiler_step = 0 - profiler_step_start = 20 - profiler_step_end = 30 - for pass_id in range(args.num_passes): py_reader.start() time.sleep(10) epoch_start = time.time() batch_id = 0 start = time.time() - try: while True: loss_val = train_exe.run(fetch_list=[loss.name]) loss_val = np.mean(loss_val) - if batch_id % 50 == 0: + if batch_id % args.print_batch == 0: logger.info( "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}". format(pass_id, batch_id, loss_val.mean(), py_reader.queue.size())) if args.with_speed: - if batch_id % 1000 == 0 and batch_id != 0: + if batch_id % 500 == 0 and batch_id != 0: elapsed = (time.time() - start) start = time.time() samples = 1001 * args.batch_size * int( os.getenv("CPU_NUM")) logger.info("Time used: {}, Samples/Sec: {}".format( elapsed, samples / elapsed)) - # calculate infer result each 100 batches when using --with_infer_test - if args.with_infer_test: - if batch_id % 1000 == 0 and batch_id != 0: - model_dir = args.model_output_dir + '/batch-' + str( - batch_id) - inference_test(global_scope(), model_dir, args) - - if batch_id % 500000 == 0 and batch_id != 0: - model_dir = args.model_output_dir + '/batch-' + str( - batch_id) - fluid.io.save_persistables(executor=exe, dirname=model_dir) - with open(model_dir + "/_success", 'w+') as f: - f.write(str(batch_id)) + + if batch_id % args.save_step == 0 and batch_id != 0: + model_dir = args.model_output_dir + '/pass-' + str( + pass_id) + ('/batch-' + str(batch_id)) + if trainer_id == 0: + fluid.io.save_params(executor=exe, dirname=model_dir) + print("model saved in %s" % model_dir) batch_id += 1 except fluid.core.EOFException: @@ -244,12 +178,10 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id): epoch_end = time.time() logger.info("Epoch: {0}, Train total expend: {1} ".format( pass_id, epoch_end - epoch_start)) - model_dir = args.model_output_dir + '/pass-' + str(pass_id) if trainer_id == 0: - fluid.io.save_persistables(executor=exe, dirname=model_dir) - with open(model_dir + "/_success", 'w+') as f: - f.write(str(pass_id)) + fluid.io.save_params(executor=exe, dirname=model_dir) + print("model saved in %s" % model_dir) def GetFileList(data_path): @@ -261,123 +193,42 @@ def train(args): if not os.path.isdir(args.model_output_dir): os.mkdir(args.model_output_dir) - filelist = GetFileList(args.train_data_path) - word2vec_reader = None - if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1": - word2vec_reader = reader.Word2VecReader( - args.dict_path, args.train_data_path, filelist, 0, 1) - else: - trainer_id = int(os.environ["PADDLE_TRAINER_ID"]) - trainer_num = int(os.environ["PADDLE_TRAINERS"]) - word2vec_reader = reader.Word2VecReader(args.dict_path, - args.train_data_path, filelist, - trainer_id, trainer_num) + filelist = GetFileList(args.train_data_dir) + 
word2vec_reader = reader.Word2VecReader(args.dict_path, args.train_data_dir, + filelist, 0, 1) logger.info("dict_size: {}".format(word2vec_reader.dict_size)) + #update id_frequency + sum_val = 0.0 + id_frequencys_pow = [] + for id_x in range(len(word2vec_reader.id_frequencys)): + id_frequencys_pow.append( + math.pow(word2vec_reader.id_frequencys[id_x], 0.75)) + sum_val += id_frequencys_pow[id_x] + id_frequencys_pow = [val / sum_val for val in id_frequencys_pow] + loss, py_reader = skip_gram_word2vec( word2vec_reader.dict_size, - word2vec_reader.word_frequencys, + id_frequencys_pow, args.embedding_size, - args.max_code_length, - args.with_hs, - args.with_nce, is_sparse=args.is_sparse) - optimizer = None - if args.with_Adam: - optimizer = fluid.optimizer.Adam(learning_rate=1e-4, lazy_mode=True) - else: - optimizer = fluid.optimizer.SGD(learning_rate=1e-4) + optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.exponential_decay( + learning_rate=args.base_lr, + decay_steps=100000, + decay_rate=0.999, + staircase=True)) optimizer.minimize(loss) # do local training - if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1": - logger.info("run local training") - main_program = fluid.default_main_program() - - with open("local.main.proto", "w") as f: - f.write(str(main_program)) - - train_loop(args, main_program, word2vec_reader, py_reader, loss, 0) - # do distribute training - else: - logger.info("run dist training") - - trainer_id = int(os.environ["PADDLE_TRAINER_ID"]) - trainers = int(os.environ["PADDLE_TRAINERS"]) - training_role = os.environ["PADDLE_TRAINING_ROLE"] - - port = os.getenv("PADDLE_PSERVER_PORT", "6174") - pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "") - eplist = [] - for ip in pserver_ips.split(","): - eplist.append(':'.join([ip, port])) - pserver_endpoints = ",".join(eplist) - current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port - - config = fluid.DistributeTranspilerConfig() - config.slice_var_up = False - t = fluid.DistributeTranspiler(config=config) - t.transpile( - trainer_id, - pservers=pserver_endpoints, - trainers=trainers, - sync_mode=True) - - if training_role == "PSERVER": - logger.info("run pserver") - prog = t.get_pserver_program(current_endpoint) - startup = t.get_startup_program( - current_endpoint, pserver_program=prog) - - with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")), - "w") as f: - f.write(str(prog)) - - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(startup) - exe.run(prog) - elif training_role == "TRAINER": - logger.info("run trainer") - train_prog = t.get_trainer_program() - - with open("trainer.main.proto.{}".format(trainer_id), "w") as f: - f.write(str(train_prog)) - - train_loop(args, train_prog, word2vec_reader, py_reader, loss, - trainer_id) - - -def env_declar(): - print("******** Rename Cluster Env to PaddleFluid Env ********") - - print("Content-Type: text/plain\n\n") - for key in os.environ.keys(): - print("%30s %s \n" % (key, os.environ[key])) - - if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[ - "PADDLE_IS_LOCAL"] == "0": - os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"] - os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"] - os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"] - os.environ["PADDLE_TRAINERS"] = os.environ["PADDLE_TRAINERS_NUM"] - os.environ["PADDLE_CURRENT_IP"] = os.environ["POD_IP"] - os.environ["PADDLE_TRAINER_ID"] = os.environ["PADDLE_TRAINER_ID"] - # we set the thread number same as CPU number - os.environ["CPU_NUM"] = "12" - - 
print("Content-Type: text/plain\n\n") - for key in os.environ.keys(): - print("%30s %s \n" % (key, os.environ[key])) - - print("****** Rename Cluster Env to PaddleFluid Env END ******") + logger.info("run local training") + main_program = fluid.default_main_program() + train_loop(args, main_program, word2vec_reader, py_reader, loss, 0, + word2vec_reader.id_frequencys) if __name__ == '__main__': args = parse_args() - if args.is_local: - pass - else: - env_declar() train(args) diff --git a/fluid/PaddleRec/word2vec/utils.py b/fluid/PaddleRec/word2vec/utils.py new file mode 100644 index 00000000..01cd04e4 --- /dev/null +++ b/fluid/PaddleRec/word2vec/utils.py @@ -0,0 +1,96 @@ +import sys +import collections +import six +import time +import numpy as np +import paddle.fluid as fluid +import paddle +import os +import preprocess + + +def BuildWord_IdMap(dict_path): + word_to_id = dict() + id_to_word = dict() + with open(dict_path, 'r') as f: + for line in f: + word_to_id[line.split(' ')[0]] = int(line.split(' ')[1]) + id_to_word[int(line.split(' ')[1])] = line.split(' ')[0] + return word_to_id, id_to_word + + +def prepare_data(file_dir, dict_path, batch_size): + w2i, i2w = BuildWord_IdMap(dict_path) + vocab_size = len(i2w) + reader = paddle.batch(test(file_dir, w2i), batch_size) + return vocab_size, reader, i2w + + +def native_to_unicode(s): + if _is_unicode(s): + return s + try: + return _to_unicode(s) + except UnicodeDecodeError: + res = _to_unicode(s, ignore_errors=True) + return res + + +def _is_unicode(s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + +def _to_unicode(s, ignore_errors=False): + if _is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + +def strip_lines(line, vocab): + return _replace_oov(vocab, native_to_unicode(line)) + + +def _replace_oov(original_vocab, line): + """Replace out-of-vocab words with "". + This maintains compatibility with published results. + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + Returns: + a unicode string - a space-delimited sequence of words. + """ + return u" ".join([ + word if word in original_vocab else u"" for word in line.split() + ]) + + +def reader_creator(file_dir, word_to_id): + def reader(): + files = os.listdir(file_dir) + for fi in files: + with open(file_dir + '/' + fi, "r") as f: + for line in f: + if ':' in line: + pass + else: + line = strip_lines(line.lower(), word_to_id) + line = line.split() + yield [word_to_id[line[0]]], [word_to_id[line[1]]], [ + word_to_id[line[2]] + ], [word_to_id[line[3]]], [ + word_to_id[line[0]], word_to_id[line[1]], + word_to_id[line[2]] + ] + + return reader + + +def test(test_dir, w2i): + return reader_creator(test_dir, w2i) -- GitLab