diff --git a/fluid/PaddleRec/word2vec/README.md b/fluid/PaddleRec/word2vec/README.md
index 812fe1e33f024989a768e48c8622984eecd95d28..ec39996fd763dd90791a27da5dee99db2c131a13 100644
--- a/fluid/PaddleRec/word2vec/README.md
+++ b/fluid/PaddleRec/word2vec/README.md
@@ -4,6 +4,9 @@
 ```text
 .
+├── cluster_train.py # distributed training script
+├── cluster_train.sh # script to simulate multi-machine training locally
+├── train.py # training script
 ├── infer.py # inference script
 ├── net.py # network structure
 ├── preprocess.py # preprocessing script: builds the dictionary and preprocesses the text
@@ -18,56 +21,87 @@
 This example implements the skip-gram mode of the word2vec model.

-## Dataset
-The large dataset comes from the 1 Billion Word Language Model Benchmark (http://www.statmt.org/lm-benchmark). Download it as follows:
+## Data Download
+The full dataset comes from the 1 Billion Word Language Model Benchmark (http://www.statmt.org/lm-benchmark).
+
+```bash
+wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
+```
+
+A backup copy can be downloaded as follows:
 ```bash
-wget https://paddle-zwh.bj.bcebos.com/1-billion-word-language-modeling-benchmark-r13output.tar
+wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
 ```

-The small dataset is text8, which contains 17 million words. Download it as follows:
+For quick verification we also provide the classic text8 sample dataset, which contains 17 million words. Download it as follows:

-Download the dataset:
 ```bash
-wget https://paddle-zwh.bj.bcebos.com/text.tar
+wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
 ```

 ## Data Preprocessing
-The following preprocesses the small dataset as an example.
+The sample dataset is used below as the preprocessing example. For the full dataset, note that after extraction the training-monolingual.tokenized.shuffled directory is the preprocessing directory, parallel to the sample dataset's text directory.
+
+Dictionary format: word<space>count. Note that low-frequency words are represented by '<UNK>'.
-For the large dataset, note that after extraction the training-monolingual.tokenized.shuffled directory is the preprocessing directory, parallel to the small dataset's text directory.
+You can also build your own dictionary in this format; if you do, skip step 1.
+```
+the 1061396
+of 593677
+and 416629
+one 411764
+in 372201
+a 325873
+<UNK> 324608
+to 316376
+zero 264975
+nine 250430
+```
-Build the dictionary from the English corpus; a Chinese corpus can be handled by modifying text_strip
+Step 1: build the dictionary from the English corpus. For a Chinese corpus, you can customize the processing by modifying the text_strip method.
 ```bash
 python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
 ```
-Use the dictionary to convert the text to ids, downsampling at the same time to filter frequent words probabilistically.
+Step 2: use the dictionary to convert the text to ids, downsampling at the same time to filter frequent words probabilistically.
 ```bash
 python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text/ --output_corpus_dir data/convert_text8 --min_count 5 --downsample 0.001
 ```

 ## Training
-Single-machine multi-threaded CPU training
+To see the full list of configurable parameters, run
+
+```bash
+python train.py -h
+```
+
+Single-machine multi-threaded training
 ```bash
 OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes 10 --batch_size 100 --model_output_dir v1_cpu5_b100_lr1dir --base_lr 1.0 --print_batch 1000 --with_speed --is_sparse
 ```
+
+Simulate multi-machine training locally on a single machine
+
+```bash
+sh cluster_train.sh
+```
+
 ## Inference
 Download the test sets as follows:
 ```bash
-# test set for the large dataset
-wget https://paddle-zwh.bj.bcebos.com/test_dir.tar
-# test set for the small dataset
-wget https://paddle-zwh.bj.bcebos.com/test_mid_dir.tar
+# test set for the full dataset
+wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
+# test set for the sample dataset
+wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
 ```
-Inference command:
+Inference command. Note that the dictionary name needs the suffix "_word_to_id_"; this file is generated during training.
 ```bash
 python infer.py --infer_epoch --test_dir data/test_mid_dir/ --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0
 ```
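Editorial aside (not part of the patch): the README above describes the dictionary as plain `word<space>count` lines with an `<UNK>` entry for rare words, and says you may build such a file yourself. A minimal standalone sketch of producing a file in that format is shown below; the script name, the single-file input, and the min_count folding rule are illustrative assumptions, and `preprocess.py --build_dict` remains the supported path whose exact rules may differ.

```python
# build_dict_sketch.py -- hypothetical helper, not part of this patch.
# Counts whitespace-separated tokens and writes "word<space>count" lines,
# folding words rarer than min_count into a single <UNK> entry.
import collections
import sys


def build_dict(corpus_path, dict_path, min_count=5):
    counts = collections.Counter()
    with open(corpus_path) as f:
        for line in f:
            counts.update(line.strip().split())
    unk = sum(c for w, c in counts.items() if c < min_count)
    kept = {w: c for w, c in counts.items() if c >= min_count}
    kept['<UNK>'] = unk
    with open(dict_path, 'w') as f:
        for word, count in sorted(kept.items(), key=lambda x: -x[1]):
            f.write('%s %d\n' % (word, count))


if __name__ == '__main__':
    build_dict(sys.argv[1], sys.argv[2])
```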
diff --git a/fluid/PaddleRec/word2vec/cluster_train.py b/fluid/PaddleRec/word2vec/cluster_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e95511df32fd6f08dd268bd1c78e1cfe2b79ab
--- /dev/null
+++ b/fluid/PaddleRec/word2vec/cluster_train.py
@@ -0,0 +1,250 @@
+from __future__ import print_function
+import argparse
+import logging
+import os
+import time
+import math
+import random
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import six
+import reader
+from net import skip_gram_word2vec
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="PaddlePaddle Word2vec example")
+    parser.add_argument(
+        '--train_data_dir',
+        type=str,
+        default='./data/text',
+        help="The path of the training dataset")
+    parser.add_argument(
+        '--base_lr',
+        type=float,
+        default=0.01,
+        help="The base learning rate (default: 0.01)")
+    parser.add_argument(
+        '--save_step',
+        type=int,
+        default=500000,
+        help="The number of steps between model saves (default: 500000)")
+    parser.add_argument(
+        '--print_batch',
+        type=int,
+        default=100,
+        help="The number of batches between log lines (default: 100)")
+    parser.add_argument(
+        '--dict_path',
+        type=str,
+        default='./data/1-billion_dict',
+        help="The path of the data dict")
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=500,
+        help="The size of mini-batch (default: 500)")
+    parser.add_argument(
+        '--num_passes',
+        type=int,
+        default=10,
+        help="The number of passes to train (default: 10)")
+    parser.add_argument(
+        '--model_output_dir',
+        type=str,
+        default='models',
+        help='The path for the model to be stored (default: models)')
+    parser.add_argument('--nce_num', type=int, default=5, help='nce_num')
+    parser.add_argument(
+        '--embedding_size',
+        type=int,
+        default=64,
+        help='sparse feature hashing space for index processing')
+    parser.add_argument(
+        '--is_sparse',
+        action='store_true',
+        required=False,
+        default=False,
+        help='embedding and nce will use sparse updates or not (default: False)')
+    parser.add_argument(
+        '--with_speed',
+        action='store_true',
+        required=False,
+        default=False,
+        help='print speed or not (default: False)')
+    parser.add_argument(
+        '--role', type=str, default='pserver', help='trainer or pserver')
+    parser.add_argument(
+        '--endpoints',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
+    parser.add_argument(
+        '--current_endpoint',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The current pserver endpoint')
+    parser.add_argument(
+        '--trainer_id',
+        type=int,
+        default=0,
+        help='trainer id; only trainer_id=0 saves the model')
+    parser.add_argument(
+        '--trainers',
+        type=int,
+        default=1,
+        help='The number of trainers (default: 1)')
+    return parser.parse_args()
+
+
+def convert_python_to_tensor(weight, batch_size, sample_reader):
+    # Wraps the python sample reader: batches (input_word, true_word) pairs
+    # into tensors and appends one shared tensor of negative-sample ids drawn
+    # from the cumulative distribution built from `weight`.
+    def __reader__():
+        cs = np.array(weight).cumsum()
+        result = [[], []]
+        for sample in sample_reader():
+            for i, fea in enumerate(sample):
+                result[i].append(fea)
+            if len(result[0]) == batch_size:
+                tensor_result = []
+                for tensor in result:
+                    t = fluid.Tensor()
+                    dat = np.array(tensor, dtype='int64')
+                    if len(dat.shape) > 2:
+                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
+                    elif len(dat.shape) == 1:
+                        dat = dat.reshape((-1, 1))
+                    t.set(dat, fluid.CPUPlace())
+                    tensor_result.append(t)
+                tt = fluid.Tensor()
+                # Sample nce_num negative ids and reuse them for the whole batch.
+                neg_array = cs.searchsorted(np.random.sample(args.nce_num))
+                neg_array = np.tile(neg_array, batch_size)
+                tt.set(
+                    neg_array.reshape((batch_size, args.nce_num)),
+                    fluid.CPUPlace())
+                tensor_result.append(tt)
+                yield tensor_result
+                result = [[], []]
+
+    return __reader__
+
+
+def train_loop(args, train_program, reader, py_reader, loss, trainer_id,
+               weight):
+
+    py_reader.decorate_tensor_provider(
+        convert_python_to_tensor(weight, args.batch_size, reader.train()))
+
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
print("CPU_NUM:" + str(os.getenv("CPU_NUM"))) + + train_exe = exe + + for pass_id in range(args.num_passes): + py_reader.start() + time.sleep(10) + epoch_start = time.time() + batch_id = 0 + start = time.time() + try: + while True: + + loss_val = train_exe.run(fetch_list=[loss.name]) + loss_val = np.mean(loss_val) + + if batch_id % args.print_batch == 0: + logger.info( + "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}". + format(pass_id, batch_id, + loss_val.mean(), py_reader.queue.size())) + if args.with_speed: + if batch_id % 500 == 0 and batch_id != 0: + elapsed = (time.time() - start) + start = time.time() + samples = 1001 * args.batch_size * int( + os.getenv("CPU_NUM")) + logger.info("Time used: {}, Samples/Sec: {}".format( + elapsed, samples / elapsed)) + + if batch_id % args.save_step == 0 and batch_id != 0: + model_dir = args.model_output_dir + '/pass-' + str( + pass_id) + ('/batch-' + str(batch_id)) + if trainer_id == 0: + fluid.io.save_params(executor=exe, dirname=model_dir) + print("model saved in %s" % model_dir) + batch_id += 1 + + except fluid.core.EOFException: + py_reader.reset() + epoch_end = time.time() + logger.info("Epoch: {0}, Train total expend: {1} ".format( + pass_id, epoch_end - epoch_start)) + model_dir = args.model_output_dir + '/pass-' + str(pass_id) + if trainer_id == 0: + fluid.io.save_params(executor=exe, dirname=model_dir) + print("model saved in %s" % model_dir) + + +def GetFileList(data_path): + return os.listdir(data_path) + + +def train(args): + + if not os.path.isdir(args.model_output_dir) and args.train_id == 0: + os.mkdir(args.model_output_dir) + + filelist = GetFileList(args.train_data_dir) + word2vec_reader = reader.Word2VecReader(args.dict_path, args.train_data_dir, + filelist, 0, 1) + + logger.info("dict_size: {}".format(word2vec_reader.dict_size)) + np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75) + id_frequencys_pow = np_power / np_power.sum() + + loss, py_reader = skip_gram_word2vec( + word2vec_reader.dict_size, + args.embedding_size, + is_sparse=args.is_sparse, + neg_num=args.nce_num) + + optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.exponential_decay( + learning_rate=args.base_lr, + decay_steps=100000, + decay_rate=0.999, + staircase=True)) + + optimizer.minimize(loss) + + logger.info("run dist training") + + t = fluid.DistributeTranspiler() + t.transpile( + args.trainer_id, pservers=args.endpoints, trainers=args.trainers) + if args.role == "pserver": + print("run psever") + pserver_prog = t.get_pserver_program(args.current_endpoint) + pserver_startup = t.get_startup_program(args.current_endpoint, + pserver_prog) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(pserver_startup) + exe.run(pserver_prog) + elif args.role == "trainer": + print("run trainer") + train_loop(args, + t.get_trainer_program(), word2vec_reader, py_reader, loss, + args.trainer_id, id_frequencys_pow) + + +if __name__ == '__main__': + args = parse_args() + train(args) diff --git a/fluid/PaddleRec/word2vec/cluster_train.sh b/fluid/PaddleRec/word2vec/cluster_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..756196fd41eeb52d5f43553664c824748ac83e4e --- /dev/null +++ b/fluid/PaddleRec/word2vec/cluster_train.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +#export GLOG_v=30 +#export GLOG_logtostderr=1 + +# start pserver0 +export CPU_NUM=5 +export FLAGS_rpc_deadline=3000000 +python cluster_train.py \ + --train_data_dir data/convert_text8 \ + --dict_path data/test_build_dict \ + --batch_size 100 \ + --model_output_dir 
diff --git a/fluid/PaddleRec/word2vec/cluster_train.sh b/fluid/PaddleRec/word2vec/cluster_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..756196fd41eeb52d5f43553664c824748ac83e4e
--- /dev/null
+++ b/fluid/PaddleRec/word2vec/cluster_train.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+#export GLOG_v=30
+#export GLOG_logtostderr=1
+
+# start pserver0
+export CPU_NUM=5
+export FLAGS_rpc_deadline=3000000
+python cluster_train.py \
+    --train_data_dir data/convert_text8 \
+    --dict_path data/test_build_dict \
+    --batch_size 100 \
+    --model_output_dir dis_model \
+    --base_lr 1.0 \
+    --print_batch 1 \
+    --is_sparse \
+    --with_speed \
+    --role pserver \
+    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
+    --current_endpoint 127.0.0.1:6000 \
+    --trainers 2 \
+    > pserver0.log 2>&1 &
+
+# start pserver1
+python cluster_train.py \
+    --train_data_dir data/convert_text8 \
+    --dict_path data/test_build_dict \
+    --batch_size 100 \
+    --model_output_dir dis_model \
+    --base_lr 1.0 \
+    --print_batch 1 \
+    --is_sparse \
+    --with_speed \
+    --role pserver \
+    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
+    --current_endpoint 127.0.0.1:6001 \
+    --trainers 2 \
+    > pserver1.log 2>&1 &
+
+# start trainer0
+python cluster_train.py \
+    --train_data_dir data/convert_text8 \
+    --dict_path data/test_build_dict \
+    --batch_size 100 \
+    --model_output_dir dis_model \
+    --base_lr 1.0 \
+    --print_batch 1000 \
+    --is_sparse \
+    --with_speed \
+    --role trainer \
+    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
+    --trainers 2 \
+    --trainer_id 0 \
+    > trainer0.log 2>&1 &
+
+# start trainer1
+python cluster_train.py \
+    --train_data_dir data/convert_text8 \
+    --dict_path data/test_build_dict \
+    --batch_size 100 \
+    --model_output_dir dis_model \
+    --base_lr 1.0 \
+    --print_batch 1000 \
+    --is_sparse \
+    --with_speed \
+    --role trainer \
+    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
+    --trainers 2 \
+    --trainer_id 1 \
+    > trainer1.log 2>&1 &
diff --git a/fluid/PaddleRec/word2vec/net.py b/fluid/PaddleRec/word2vec/net.py
index b900c62a35844d8aa9706ccda533eae8a7705f34..1945b62a44925154583aadca173883359378a02a 100644
--- a/fluid/PaddleRec/word2vec/net.py
+++ b/fluid/PaddleRec/word2vec/net.py
@@ -20,17 +20,13 @@ import numpy as np
 import paddle.fluid as fluid


-def skip_gram_word2vec(dict_size,
-                       word_frequencys,
-                       embedding_size,
-                       is_sparse=False):
+def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
     datas = []
-    neg_num = 5
     input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
     true_word = fluid.layers.data(name='true_label', shape=[1], dtype='int64')
     neg_word = fluid.layers.data(
-        name="neg_label", shape=[neg_num, 5], dtype='int64')
+        name="neg_label", shape=[neg_num], dtype='int64')

     datas.append(input_word)
     datas.append(true_word)
@@ -70,8 +66,7 @@ def skip_gram_word2vec(dict_size,
         is_sparse=is_sparse,
         size=[dict_size, embedding_size],
         param_attr=fluid.ParamAttr(
-            name='emb_w', learning_rate=10.0))
-    # param_attr='emb_w')
+            name='emb_w', learning_rate=1.0))

     neg_emb_w_re = fluid.layers.reshape(
         neg_emb_w, shape=[-1, neg_num, embedding_size])
@@ -80,8 +75,7 @@ def skip_gram_word2vec(dict_size,
         is_sparse=is_sparse,
         size=[dict_size, 1],
         param_attr=fluid.ParamAttr(
-            name='emb_b', learning_rate=10.0))
-    #param_attr='emb_b')
+            name='emb_b', learning_rate=1.0))

     neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])

     true_logits = fluid.layers.elementwise_add(
@@ -95,7 +89,6 @@ def skip_gram_word2vec(dict_size,
     neg_matmul = fluid.layers.matmul(
         input_emb_re, neg_emb_w_re, transpose_y=True)
     neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num])
-    #neg_matmul_reshape = fluid.layers.reshape(neg_matmul, shape=[-1, ]
     neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)

     #nce loss
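Editorial aside (not part of the patch): the `neg_label` shape change above means each example carries `neg_num` negative word ids, and `skip_gram_word2vec` produces one logit for the true word plus `neg_num` logits for the sampled words via the reshape/matmul sequence in the hunks. A numpy sketch of just the logit computation, with random arrays standing in for the embedding lookups; the sigmoid-cross-entropy loss that follows is outside this hunk and not shown here:

```python
import numpy as np

batch_size, embedding_size, neg_num = 4, 8, 5

# Stand-ins for the embedding lookups in skip_gram_word2vec.
input_emb = np.random.rand(batch_size, embedding_size)
true_emb_w = np.random.rand(batch_size, embedding_size)
true_emb_b = np.random.rand(batch_size, 1)
neg_emb_w = np.random.rand(batch_size, neg_num, embedding_size)
neg_emb_b = np.random.rand(batch_size, neg_num)

# true_logits: dot(input, true word weights) + bias -> shape [batch_size, 1]
true_logits = (input_emb * true_emb_w).sum(axis=1, keepdims=True) + true_emb_b

# neg_logits: batched matmul [batch, 1, emb] x [batch, emb, neg_num] -> [batch, neg_num]
input_emb_re = input_emb.reshape(batch_size, 1, embedding_size)
neg_matmul = np.matmul(input_emb_re, neg_emb_w.transpose(0, 2, 1))
neg_logits = neg_matmul.reshape(batch_size, neg_num) + neg_emb_b

print(true_logits.shape, neg_logits.shape)  # (4, 1) (4, 5)
```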
{}".format(word2vec_reader.dict_size)) - #update id_frequency - sum_val = 0.0 - id_frequencys_pow = [] - for id_x in range(len(word2vec_reader.id_frequencys)): - id_frequencys_pow.append( - math.pow(word2vec_reader.id_frequencys[id_x], 0.75)) - sum_val += id_frequencys_pow[id_x] - id_frequencys_pow = [val / sum_val for val in id_frequencys_pow] + np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75) + id_frequencys_pow = np_power / np_power.sum() loss, py_reader = skip_gram_word2vec( word2vec_reader.dict_size, - id_frequencys_pow, args.embedding_size, - is_sparse=args.is_sparse) + is_sparse=args.is_sparse, + neg_num=args.nce_num) optimizer = fluid.optimizer.SGD( learning_rate=fluid.layers.exponential_decay( @@ -226,7 +220,7 @@ def train(args): logger.info("run local training") main_program = fluid.default_main_program() train_loop(args, main_program, word2vec_reader, py_reader, loss, 0, - word2vec_reader.id_frequencys) + id_frequencys_pow) if __name__ == '__main__':