diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py index 278df79f4d6e116a19e106ad141e417779f6c02d..239b568be34793c5ddb0830e9cca06951da143f4 100644 --- a/core/trainers/framework/dataset.py +++ b/core/trainers/framework/dataset.py @@ -68,6 +68,8 @@ class DataLoader(DatasetBase): reader_ins = SlotReader(context["config_yaml"]) if hasattr(reader_ins, 'generate_batch_from_trainfiles'): dataloader.set_sample_list_generator(reader) + elif hasattr(reader_ins, 'batch_tensor_creator'): + dataloader.set_batch_generator(reader) else: dataloader.set_sample_generator(reader, batch_size) return dataloader diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py index f484626a73f69481f3ae51b35fc5b6e717870938..03e6f0a67884917e9af2d02d13eb86576620ceef 100755 --- a/core/utils/dataloader_instance.py +++ b/core/utils/dataloader_instance.py @@ -83,6 +83,10 @@ def dataloader_by_name(readerclass, if hasattr(reader, 'generate_batch_from_trainfiles'): return gen_batch_reader() + + if hasattr(reader, "batch_tensor_creator"): + return reader.batch_tensor_creator(gen_reader) + return gen_reader diff --git a/doc/imgs/w2v_train.png b/doc/imgs/w2v_train.png new file mode 100644 index 0000000000000000000000000000000000000000..b32aa14327003e138ec7ccbf035866c4bf73edf7 Binary files /dev/null and b/doc/imgs/w2v_train.png differ diff --git a/models/recall/word2vec/README.md b/models/recall/word2vec/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7f6ed7522801822b9281df84a6a6924f30eb2f87 --- /dev/null +++ b/models/recall/word2vec/README.md @@ -0,0 +1,248 @@ +# Skip-Gram W2V + +以下是本例的简要目录结构及说明: + +``` +├── data #样例数据 + ├── train + ├── convert_sample.txt + ├── test + ├── sample.txt + ├── dict + ├── word_count_dict.txt + ├── word_id_dict.txt +├── preprocess.py # 数据预处理文件 +├── __init__.py +├── README.md # 文档 +├── model.py #模型文件 +├── config.yaml #配置文件 +├── data_prepare.sh #一键数据处理脚本 +├── w2v_reader.py #训练数据reader +├── w2v_evaluate_reader.py # 预测数据reader +├── infer.py # 自定义预测脚本 +├── utils.py # 自定义预测中用到的reader等工具 +``` + +注:在阅读该示例前,建议您先了解以下内容: + +[paddlerec入门教程](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md) + + +--- +## 内容 + +- [模型简介](#模型简介) +- [数据准备](#数据准备) +- [运行环境](#运行环境) +- [快速开始](#快速开始) +- [论文复现](#论文复现) +- [进阶使用](#进阶使用) +- [FAQ](#FAQ) + +## 模型简介 +本例实现了skip-gram模式的word2vector模型,如下图所示: +

+ +

+以每一个词为中心词X,然后在窗口内和临近的词Y组成样本对(X,Y)用于网络训练。在实际训练过程中还会根据自定义的负采样率生成负样本来加强训练的效果 +具体的训练思路如下: +

+ +

+ +推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377)教程获取更详细的信息。 + +本模型配置默认使用demo数据集,若进行精度验证,请参考[论文复现](#论文复现)部分。 + +本项目支持功能 + +训练:单机CPU、本地模拟参数服务器训练、增量训练,配置请参考 [启动训练](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md) + +预测:单机CPU;配置请参考[PaddleRec 离线预测](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md) + +## 数据处理 +为和样例数据路径区分,全量训练数据、测试数据、词表文件会依次保存在data/all_train, data/all_test, data/all_dict文件夹中。 +``` +mkdir -p data/all_dict +mkdir -p data/all_train +mkdir -p data/all_test +``` +本示例中全量数据处理共包含三步: +- Step1: 数据下载。 + ``` + # 全量训练集 + mkdir raw_data + wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar + tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar + mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ raw_data/ + + # 测试集 + wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar + tar xzvf test_dir.tar -C raw_data + mv raw_data/data/test_dir/* data/all_test/ + ``` + +- Step2: 训练据预处理。包含三步,第一步,根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。 + ``` + python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt + ``` + 得到的词典格式为词<空格>词频,低频词用'UNK'表示,如下所示: + ``` + the 1061396 + of 593677 + and 416629 + one 411764 + in 372201 + a 325873 + 324608 + to 316376 + zero 264975 + nine 250430 + ``` + 第二步,根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"word_to_id"。 + ``` + python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001 + ``` + 第三步,为更好地利用多线程进行训练加速,我们需要将训练文件分成多个子文件,默认拆分成1024个文件。 + ``` + python preprocess.py --data_resplit --input_corpus_dir=raw_data/convert_text8 --output_corpus_dir=data/all_train + ``` +- Step3: 路径整理。 + ``` + mv raw_data/word_count_dict.txt data/all_dict/ + mv raw_data/word_count_dict.txt_word_to_id_ data/all_dict/word_id_dict.txt + rm -rf raw_data + ``` +方便起见, 我们提供了一键式数据处理脚本: +``` +sh data_prepare.sh +``` + +## 运行环境 + +PaddlePaddle>=1.7.2 + +python 2.7/3.5/3.6/3.7 + +PaddleRec >=0.1 + +os : windows/linux/macos + +## 快速开始 + +### 单机训练 + +CPU环境 + +在config.yaml文件中设置好设备,epochs等。 + +``` +# select runner by name +mode: [single_cpu_train, single_cpu_infer] +# config of each runner. +# runner is a kind of paddle training class, which wraps the train/infer process. +runner: +- name: single_cpu_train + class: train + # num of epochs + epochs: 5 + # device to run training or infer + device: cpu + save_checkpoint_interval: 1 # save model interval of epochs + save_inference_interval: 1 # save inference + save_checkpoint_path: "increment_w2v" # save checkpoint path + save_inference_path: "inference_w2v" # save inference path + save_inference_feed_varnames: [] # feed vars of save inference + save_inference_fetch_varnames: [] # fetch vars of save inference + init_model_path: "" # load model path + print_interval: 1 + phases: [phase1] +``` +### 单机预测 +我们通过词类比(Word Analogy)任务来检验word2vec模型的训练效果。输入四个词A,B,C,D,假设存在一种关系relation, 使得relation(A, B) = relation(C, D),然后通过A,B,C去预测D,emb(D) = emb(B) - emb(A) + emb(C)。 + +CPU环境 + +PaddleRec预测配置: + +在config.yaml文件中设置好epochs、device等参数。 + +``` +- name: single_cpu_infer + class: infer + # device to run training or infer + device: cpu + init_model_path: "increment_w2v" # load model path + print_interval: 1 + phases: [phase2] +``` + +为复现论文效果,我们提供了一个自定义预测脚本,在自定义预测中,我们会跳过预测结果是输入A,B,C的情况,然后计算预测准确率。执行命令如下: +``` +python infer.py --test_dir ./data/test --dict_path ./data/dict/word_id_dict.txt --batch_size 20000 --model_dir ./increment_w2v/ --start_index 0 --last_index 5 --emb_size 300 +``` + +### 运行 +``` +python -m paddlerec.run -m paddlerec.models.recall.word2vec +``` + +### 结果展示 + +样例数据训练结果展示: + +``` +Running SingleStartup. +Running SingleRunner. +W0813 11:36:16.129736 43843 build_strategy.cc:170] fusion_group is not enabled for Windows/MacOS now, and only effective when running with CUDA GPU. +batch: 1, LOSS: [3.618 3.684 3.698 3.653 3.736] +batch: 2, LOSS: [3.394 3.453 3.605 3.487 3.553] +batch: 3, LOSS: [3.411 3.402 3.444 3.387 3.357] +batch: 4, LOSS: [3.557 3.196 3.304 3.209 3.299] +batch: 5, LOSS: [3.217 3.141 3.168 3.114 3.315] +batch: 6, LOSS: [3.342 3.219 3.124 3.207 3.282] +batch: 7, LOSS: [3.19 3.207 3.136 3.322 3.164] +epoch 0 done, use time: 0.119026899338, global metrics: LOSS=[3.19 3.207 3.136 3.322 3.164] +... +epoch 4 done, use time: 0.097608089447, global metrics: LOSS=[2.734 2.66 2.763 2.804 2.809] +``` +样例数据预测结果展示: +``` +Running SingleInferStartup. +Running SingleInferRunner. +load persistables from increment_w2v/4 +batch: 1, acc: [1.] +batch: 2, acc: [1.] +batch: 3, acc: [1.] +Infer phase2 of epoch 4 done, use time: 4.89376211166, global metrics: acc=[1.] +... +Infer phase2 of epoch 3 done, use time: 4.43099021912, global metrics: acc=[1.] +``` + +## 论文复现 + +1. 用原论文的完整数据复现论文效果需要在config.yaml修改超参: +- name: dataset_train + batch_size: 100 # 1. 修改batch_size为100 + type: DataLoader + data_path: "{workspace}/data/all_train" # 2. 修改数据为全量训练数据 + word_count_dict_path: "{workspace}/data/all_dict/ word_count_dict.txt" # 3. 修改词表为全量词表 + data_converter: "{workspace}/w2v_reader.py" + +- name: single_cpu_train + - epochs: # 4. 修改config.yaml中runner的epochs为5。 + +修改后运行方案:修改config.yaml中的'workspace'为config.yaml的目录位置,执行 +``` +python -m paddlerec.run -m /home/your/dir/config.yaml #调试模式 直接指定本地config的绝对路径 +``` + +2. 使用自定义预测程序预测全量测试集: +``` +python infer.py --test_dir ./data/all_test --dict_path ./data/all_dict/word_id_dict.txt --batch_size 20000 --model_dir ./increment_w2v/ --start_index 0 --last_index 5 --emb_size 300 +``` + +结论:使用cpu训练5轮,自定义预测准确率为0.540,每轮训练时间7小时左右。 +## 进阶使用 + +## FAQ diff --git a/models/recall/word2vec/config.yaml b/models/recall/word2vec/config.yaml index 96f47221f26af19dc6cb505d34a27ca1295b4b50..7da9da0051ae5b6dab9c223c7064c76d5bd4fe5a 100755 --- a/models/recall/word2vec/config.yaml +++ b/models/recall/word2vec/config.yaml @@ -22,7 +22,7 @@ dataset: word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt" data_converter: "{workspace}/w2v_reader.py" - name: dataset_infer # name - batch_size: 50 + batch_size: 2000 type: DataLoader # or QueueDataset data_path: "{workspace}/data/test" word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt" @@ -42,38 +42,40 @@ hyper_parameters: window_size: 5 # select runner by name -mode: train_runner +mode: [single_cpu_train, single_cpu_infer] # config of each runner. # runner is a kind of paddle training class, which wraps the train/infer process. runner: -- name: train_runner +- name: single_cpu_train class: train # num of epochs - epochs: 2 + epochs: 5 # device to run training or infer device: cpu save_checkpoint_interval: 1 # save model interval of epochs save_inference_interval: 1 # save inference - save_checkpoint_path: "increment" # save checkpoint path - save_inference_path: "inference" # save inference path + save_checkpoint_path: "increment_w2v" # save checkpoint path + save_inference_path: "inference_w2v" # save inference path save_inference_feed_varnames: [] # feed vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference init_model_path: "" # load model path - print_interval: 1 -- name: infer_runner + print_interval: 1000 + phases: [phase1] +- name: single_cpu_infer class: infer # device to run training or infer device: cpu - init_model_path: "increment/0" # load model path + init_model_path: "increment_w2v" # load model path print_interval: 1 + phases: [phase2] # runner will run all the phase in each epoch phase: - name: phase1 model: "{workspace}/model.py" # user-defined model dataset_name: dataset_train # select dataset by name + thread_num: 5 +- name: phase2 + model: "{workspace}/model.py" # user-defined model + dataset_name: dataset_infer # select dataset by name thread_num: 1 -# - name: phase2 -# model: "{workspace}/model.py" # user-defined model -# dataset_name: dataset_infer # select dataset by name -# thread_num: 1 diff --git a/models/recall/word2vec/data_prepare.sh b/models/recall/word2vec/data_prepare.sh index eb1665240f58ae67afb942885932886e59e5a0a3..0cf8e3202fdda8541370e7726380c5a02a636a20 100644 --- a/models/recall/word2vec/data_prepare.sh +++ b/models/recall/word2vec/data_prepare.sh @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +mkdir -p data/all_dict +mkdir -p data/all_train +mkdir -p data/all_test # download train_data mkdir raw_data @@ -21,18 +24,16 @@ wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/1-billion-w tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ raw_data/ +# download test data +wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar +tar xzvf test_dir.tar -C raw_data +mv raw_data/data/test_dir/* data/all_test/ + # preprocess data python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001 -mv raw_data/word_count_dict.txt data/dict/ -mv raw_data/word_id_dict.txt data/dict/ +python preprocess.py --data_resplit --input_corpus_dir=raw_data/convert_text8 --output_corpus_dir=data/all_train -rm -rf data/train/* -rm -rf data/test/* -python preprocess.py --data_resplit --input_corpus_dir=raw_data/convert_text8 --output_corpus_dir=data/train - -# download test data -wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar -tar xzvf test_dir.tar -C raw_data -mv raw_data/data/test_dir/* data/test/ +mv raw_data/word_count_dict.txt data/all_dict/ +mv raw_data/word_count_dict.txt_word_to_id_ data/all_dict/word_id_dict.txt rm -rf raw_data diff --git a/models/recall/word2vec/infer.py b/models/recall/word2vec/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..9593a70b5c4880c02bb5ea1ae70e9de4e45b36a7 --- /dev/null +++ b/models/recall/word2vec/infer.py @@ -0,0 +1,155 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys +import time +import math +import numpy as np +import six +import paddle.fluid as fluid +import paddle +import utils +if six.PY2: + reload(sys) + sys.setdefaultencoding('utf-8') + + +def parse_args(): + parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example") + parser.add_argument( + '--dict_path', + type=str, + default='./data/data_c/1-billion_dict_word_to_id_', + help="The path of dic") + parser.add_argument( + '--test_dir', type=str, default='test_data', help='test file address') + parser.add_argument( + '--print_step', type=int, default='500000', help='print step') + parser.add_argument( + '--start_index', type=int, default='0', help='start index') + parser.add_argument( + '--last_index', type=int, default='100', help='last index') + parser.add_argument( + '--model_dir', type=str, default='model', help='model dir') + parser.add_argument( + '--use_cuda', type=int, default='0', help='whether use cuda') + parser.add_argument( + '--batch_size', type=int, default='5', help='batch_size') + parser.add_argument( + '--emb_size', type=int, default='64', help='batch_size') + args = parser.parse_args() + return args + + +def infer_network(vocab_size, emb_size): + analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64') + analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64') + analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64') + all_label = fluid.data(name="all_label", shape=[vocab_size], dtype='int64') + emb_all_label = fluid.embedding( + input=all_label, size=[vocab_size, emb_size], param_attr="emb") + + emb_a = fluid.embedding( + input=analogy_a, size=[vocab_size, emb_size], param_attr="emb") + emb_b = fluid.embedding( + input=analogy_b, size=[vocab_size, emb_size], param_attr="emb") + emb_c = fluid.embedding( + input=analogy_c, size=[vocab_size, emb_size], param_attr="emb") + target = fluid.layers.elementwise_add( + fluid.layers.elementwise_sub(emb_b, emb_a), emb_c) + emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1) + dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True) + values, pred_idx = fluid.layers.topk(input=dist, k=4) + return values, pred_idx + + +def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w): + """ inference function """ + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + emb_size = args.emb_size + batch_size = args.batch_size + with fluid.scope_guard(fluid.Scope()): + main_program = fluid.Program() + with fluid.program_guard(main_program): + values, pred = infer_network(vocab_size, emb_size) + for epoch in range(start_index, last_index + 1): + copy_program = main_program.clone() + model_path = model_dir + "/" + str(epoch) + fluid.io.load_persistables( + exe, model_path, main_program=copy_program) + accum_num = 0 + accum_num_sum = 0.0 + t0 = time.time() + step_id = 0 + for data in test_reader(): + step_id += 1 + b_size = len([dat[0] for dat in data]) + wa = np.array([dat[0] for dat in data]).astype( + "int64").reshape(b_size) + wb = np.array([dat[1] for dat in data]).astype( + "int64").reshape(b_size) + wc = np.array([dat[2] for dat in data]).astype( + "int64").reshape(b_size) + + label = [dat[3] for dat in data] + input_word = [dat[4] for dat in data] + para = exe.run(copy_program, + feed={ + "analogy_a": wa, + "analogy_b": wb, + "analogy_c": wc, + "all_label": np.arange(vocab_size) + .reshape(vocab_size).astype("int64"), + }, + fetch_list=[pred.name, values], + return_numpy=False) + pre = np.array(para[0]) + val = np.array(para[1]) + for ii in range(len(label)): + top4 = pre[ii] + accum_num_sum += 1 + for idx in top4: + if int(idx) in input_word[ii]: + continue + if int(idx) == int(label[ii][0]): + accum_num += 1 + break + if step_id % 1 == 0: + print("step:%d %d " % (step_id, accum_num)) + + print("epoch:%d \t acc:%.3f " % + (epoch, 1.0 * accum_num / accum_num_sum)) + + +if __name__ == "__main__": + args = parse_args() + start_index = args.start_index + last_index = args.last_index + test_dir = args.test_dir + model_dir = args.model_dir + batch_size = args.batch_size + dict_path = args.dict_path + use_cuda = True if args.use_cuda else False + print("start index: ", start_index, " last_index:", last_index) + vocab_size, test_reader, id2word = utils.prepare_data( + test_dir, dict_path, batch_size=batch_size) + print("vocab_size:", vocab_size) + infer_epoch( + args, + vocab_size, + test_reader=test_reader, + use_cuda=use_cuda, + i2w=id2word) diff --git a/models/recall/word2vec/model.py b/models/recall/word2vec/model.py index 138bfd1a7dd4a70f454fab63c50835192f4ec107..8775e08bf094b22ff9070d5d7f6b08cab9a3d873 100755 --- a/models/recall/word2vec/model.py +++ b/models/recall/word2vec/model.py @@ -209,10 +209,10 @@ class Model(ModelBase): emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1) dist = fluid.layers.matmul( x=target, y=emb_all_label_l2, transpose_y=True) - values, pred_idx = fluid.layers.topk(input=dist, k=4) + values, pred_idx = fluid.layers.topk(input=dist, k=1) label = fluid.layers.expand( fluid.layers.unsqueeze( - inputs[3], axes=[1]), expand_times=[1, 4]) + inputs[3], axes=[1]), expand_times=[1, 1]) label_ones = fluid.layers.fill_constant_batch_size_like( label, shape=[-1, 1], value=1.0, dtype='float32') right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast( diff --git a/models/recall/word2vec/preprocess.py b/models/recall/word2vec/preprocess.py index 679458e388229c5e3b3a8f3d14d212eb1d263544..c55d22dae8437b444733778a0de11aedf4aa5e46 100755 --- a/models/recall/word2vec/preprocess.py +++ b/models/recall/word2vec/preprocess.py @@ -228,7 +228,7 @@ def data_split(args): contents.extend(f.readlines()) num = int(args.file_nums) - lines_per_file = len(contents) / num + lines_per_file = int(math.ceil(len(contents) / float(num))) print("contents: ", str(len(contents))) print("lines_per_file: ", str(lines_per_file)) diff --git a/models/recall/word2vec/utils.py b/models/recall/word2vec/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c13a3e0ed5d6a3ccd09432a0f350192847a1dcf5 --- /dev/null +++ b/models/recall/word2vec/utils.py @@ -0,0 +1,131 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import collections +import six +import time +import numpy as np +import paddle.fluid as fluid +import paddle +import os +import preprocess +import io + + +def BuildWord_IdMap(dict_path): + word_to_id = dict() + id_to_word = dict() + with io.open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word_to_id[line.split(' ')[0]] = int(line.split(' ')[1]) + id_to_word[int(line.split(' ')[1])] = line.split(' ')[0] + return word_to_id, id_to_word + + +def prepare_data(file_dir, dict_path, batch_size): + w2i, i2w = BuildWord_IdMap(dict_path) + vocab_size = len(i2w) + reader = fluid.io.batch(test(file_dir, w2i), batch_size) + return vocab_size, reader, i2w + + +def check_version(with_shuffle_batch=False): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + if with_shuffle_batch: + fluid.require_version('1.7.0') + else: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +def native_to_unicode(s): + if _is_unicode(s): + return s + try: + return _to_unicode(s) + except UnicodeDecodeError: + res = _to_unicode(s, ignore_errors=True) + return res + + +def _is_unicode(s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + +def _to_unicode(s, ignore_errors=False): + if _is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + +def strip_lines(line, vocab): + return _replace_oov(vocab, native_to_unicode(line)) + + +def _replace_oov(original_vocab, line): + """Replace out-of-vocab words with "". + This maintains compatibility with published results. + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + Returns: + a unicode string - a space-delimited sequence of words. + """ + return u" ".join([ + word if word in original_vocab else u"" for word in line.split() + ]) + + +def reader_creator(file_dir, word_to_id): + def reader(): + files = os.listdir(file_dir) + for fi in files: + with io.open( + os.path.join(file_dir, fi), "r", encoding='utf-8') as f: + for line in f: + if ':' in line: + pass + else: + line = strip_lines(line.lower(), word_to_id) + line = line.split() + yield [word_to_id[line[0]]], [word_to_id[line[1]]], [ + word_to_id[line[2]] + ], [word_to_id[line[3]]], [ + word_to_id[line[0]], word_to_id[line[1]], + word_to_id[line[2]] + ] + + return reader + + +def test(test_dir, w2i): + return reader_creator(test_dir, w2i) diff --git a/models/recall/word2vec/w2v_evaluate_reader.py b/models/recall/word2vec/w2v_evaluate_reader.py index 18571e475872b2805b40f9c7d25b471fec018ad7..40d0c66385d77e1fcdba91ca970359439319c6a3 100755 --- a/models/recall/word2vec/w2v_evaluate_reader.py +++ b/models/recall/word2vec/w2v_evaluate_reader.py @@ -76,7 +76,7 @@ class Reader(ReaderBase): def generate_sample(self, line): def reader(): if ':' in line: - pass + return features = self.strip_lines(line.lower(), self.word_to_id) features = features.split() yield [('analogy_a', [self.word_to_id[features[0]]]), diff --git a/models/recall/word2vec/w2v_reader.py b/models/recall/word2vec/w2v_reader.py index 768f380337059dbfaf793e5a8b7553703337706c..5c4e45b62c3fe11f9800ea8286264ace927c13b9 100755 --- a/models/recall/word2vec/w2v_reader.py +++ b/models/recall/word2vec/w2v_reader.py @@ -15,6 +15,7 @@ import io import numpy as np +import paddle.fluid as fluid from paddlerec.core.reader import ReaderBase from paddlerec.core.utils import envs @@ -47,6 +48,10 @@ class Reader(ReaderBase): self.with_shuffle_batch = envs.get_global_env( "hyper_parameters.with_shuffle_batch") self.random_generator = NumpyRandomInt(1, self.window_size + 1) + self.batch_size = envs.get_global_env( + "dataset.dataset_train.batch_size") + self.is_dataloader = envs.get_global_env( + "dataset.dataset_train.type") == "DataLoader" self.cs = None if not self.with_shuffle_batch: @@ -88,11 +93,46 @@ class Reader(ReaderBase): for context_id in context_word_ids: output = [('input_word', [int(target_id)]), ('true_label', [int(context_id)])] - if not self.with_shuffle_batch: + if self.with_shuffle_batch or self.is_dataloader: + yield output + else: neg_array = self.cs.searchsorted( np.random.sample(self.neg_num)) output += [('neg_label', [int(str(i)) for i in neg_array])] - yield output + yield output return reader + + def batch_tensor_creator(self, sample_reader): + def __reader__(): + result = [[], []] + for sample in sample_reader(): + for i, fea in enumerate(sample): + result[i].append(fea) + if len(result[0]) == self.batch_size: + tensor_result = [] + for tensor in result: + t = fluid.Tensor() + dat = np.array(tensor, dtype='int64') + if len(dat.shape) > 2: + dat = dat.reshape((dat.shape[0], dat.shape[2])) + elif len(dat.shape) == 1: + dat = dat.reshape((-1, 1)) + t.set(dat, fluid.CPUPlace()) + tensor_result.append(t) + if self.with_shuffle_batch: + yield tensor_result + else: + tt = fluid.Tensor() + neg_array = self.cs.searchsorted( + np.random.sample(self.neg_num)) + neg_array = np.tile(neg_array, self.batch_size) + tt.set( + neg_array.reshape((self.batch_size, self.neg_num)), + fluid.CPUPlace()) + tensor_result.append(tt) + yield tensor_result + result = [[], []] + + return __reader__