diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py
index 8059eeb09a482671b8329fb88f5b52cfd64f163b..d7f8016c2020074ebbe42d6178c69a226e16e2ca 100644
--- a/core/trainers/framework/dataset.py
+++ b/core/trainers/framework/dataset.py
@@ -68,6 +68,8 @@ class DataLoader(DatasetBase):
             reader_ins = SlotReader(context["config_yaml"])
         if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
             dataloader.set_sample_list_generator(reader)
+        elif hasattr(reader_ins, 'batch_tensor_creator'):
+            dataloader.set_batch_generator(reader)
         else:
             dataloader.set_sample_generator(reader, batch_size)
         return dataloader
diff --git a/core/utils/dataloader_instance.py b/core/utils/dataloader_instance.py
index 2461473aa79a51133db8aa319f4ee7d45981d815..15fbd3a3222d0a22d34cec4ca17ac726675feb29 100755
--- a/core/utils/dataloader_instance.py
+++ b/core/utils/dataloader_instance.py
@@ -67,6 +67,10 @@ def dataloader_by_name(readerclass,
     if hasattr(reader, 'generate_batch_from_trainfiles'):
         return gen_batch_reader()
+
+    if hasattr(reader, "batch_tensor_creator"):
+        return reader.batch_tensor_creator(gen_reader)
+
     return gen_reader
diff --git a/models/recall/word2vec/README.md b/models/recall/word2vec/README.md
index 241f236f932eed787bad2d617874531d8a6d563b..b91f8927e416499ce94c3393ca226c31908f234e 100644
--- a/models/recall/word2vec/README.md
+++ b/models/recall/word2vec/README.md
@@ -19,6 +19,8 @@
 ├── data_prepare.sh # one-click data preparation script
 ├── w2v_reader.py # training-data reader
 ├── w2v_evaluate_reader.py # inference-data reader
+├── infer.py # custom inference script
+├── utils.py # reader and other utilities used by the custom inference script
 ```
 Note: before reading this example, we recommend that you first familiarize yourself with the following:
@@ -154,9 +156,12 @@ runner:
   phases: [phase1]
 ```
 ### Single-machine inference
+We evaluate the trained word2vec model with the word-analogy task: given four words A, B, C, D, assume there is a relation such that relation(A, B) = relation(C, D); D is then predicted from A, B, C via emb(D) = emb(B) - emb(A) + emb(C).
 
 CPU environment
 
+PaddleRec inference configuration:
+
 Set parameters such as epochs and device in config.yaml.
 
 ```
@@ -168,6 +173,10 @@ CPU environment
   print_interval: 1
   phases: [phase2]
 ```
+To reproduce the result reported in the paper, we provide a custom inference script. It skips predictions that merely echo one of the inputs A, B, C and then computes the prediction accuracy. Run it as follows:
+```
+python infer.py --test_dir ./data/test --dict_path ./data/dict/word_id_dict.txt --batch_size 20000 --model_dir ./increment_w2v/ --start_index 0 --last_index 5 --emb_size 300
+```
 ### Run
 ```
@@ -212,13 +221,12 @@ Infer phase2 of epoch 3 done, use time: 4.43099021912, global metrics: acc=[1.]
 - batch_size: in config.yaml, set the batch_size of the dataset_train dataset to 100.
 - epochs: in config.yaml, set the epochs of the runner to 5.
-Training on CPU for 5 epochs gives test Recall@20: 0.540
-
 How to run after the modifications: set 'workspace' in config.yaml to the directory that contains config.yaml, then execute
 ```
 python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: pass the absolute path of the local config directly
 ```
+Training on CPU for 5 epochs gives a custom-test (inputs skipped) accuracy of 0.540.
 
 ## Advanced usage
 
 ## FAQ
diff --git a/models/recall/word2vec/config.yaml b/models/recall/word2vec/config.yaml
index 3e8347a94748560e85734a4c5a68d4b529c29ba4..baa9e83527433741b7ef33b4d07db548fa3368cd 100755
--- a/models/recall/word2vec/config.yaml
+++ b/models/recall/word2vec/config.yaml
@@ -22,7 +22,7 @@ dataset:
     word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
     data_converter: "{workspace}/w2v_reader.py"
   - name: dataset_infer # name
-    batch_size: 50
+    batch_size: 2000
     type: DataLoader # or QueueDataset
     data_path: "{workspace}/data/test"
     word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
@@ -59,7 +59,7 @@ runner:
     save_inference_feed_varnames: [] # feed vars of save inference
     save_inference_fetch_varnames: [] # fetch vars of save inference
     init_model_path: "" # load model path
-    print_interval: 1
+    print_interval: 1000
     phases: [phase1]
   - name: single_cpu_infer
     class: infer
diff --git a/models/recall/word2vec/data_prepare.sh b/models/recall/word2vec/data_prepare.sh
index eb1665240f58ae67afb942885932886e59e5a0a3..6ca2e7f2df8d10fcc5c84b42f7de7929b6b9d0da 100644
--- a/models/recall/word2vec/data_prepare.sh
+++ b/models/recall/word2vec/data_prepare.sh
@@ -25,7 +25,7 @@ mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tok
 python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt
 python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001
 mv raw_data/word_count_dict.txt data/dict/
-mv raw_data/word_id_dict.txt data/dict/
+mv raw_data/word_count_dict.txt_word_to_id_ data/dict/word_id_dict.txt
 
 rm -rf data/train/*
 rm -rf data/test/*
diff --git a/models/recall/word2vec/infer.py b/models/recall/word2vec/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9593a70b5c4880c02bb5ea1ae70e9de4e45b36a7
--- /dev/null
+++ b/models/recall/word2vec/infer.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+import time
+import math
+import numpy as np
+import six
+import paddle.fluid as fluid
+import paddle
+import utils
+if six.PY2:
+    reload(sys)
+    sys.setdefaultencoding('utf-8')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
+    parser.add_argument(
+        '--dict_path',
+        type=str,
+        default='./data/data_c/1-billion_dict_word_to_id_',
+        help="The path of the word-to-id dict")
+    parser.add_argument(
+        '--test_dir', type=str, default='test_data', help='test file address')
+    parser.add_argument(
+        '--print_step', type=int, default=500000, help='print step')
+    parser.add_argument(
+        '--start_index', type=int, default=0, help='start index')
+    parser.add_argument(
+        '--last_index', type=int, default=100, help='last index')
+    parser.add_argument(
+        '--model_dir', type=str, default='model', help='model dir')
+    parser.add_argument(
+        '--use_cuda', type=int, default=0, help='whether to use cuda')
+    parser.add_argument(
+        '--batch_size', type=int, default=5, help='batch_size')
+    parser.add_argument(
+        '--emb_size', type=int, default=64, help='embedding size')
+    args = parser.parse_args()
+    return args
+
+
+def infer_network(vocab_size, emb_size):
+    analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64')
+    analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64')
+    analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64')
+    all_label = fluid.data(name="all_label", shape=[vocab_size], dtype='int64')
+    emb_all_label = fluid.embedding(
+        input=all_label, size=[vocab_size, emb_size], param_attr="emb")
+
+    emb_a = fluid.embedding(
+        input=analogy_a, size=[vocab_size, emb_size], param_attr="emb")
+    emb_b = fluid.embedding(
+        input=analogy_b, size=[vocab_size, emb_size], param_attr="emb")
+    emb_c = fluid.embedding(
+        input=analogy_c, size=[vocab_size, emb_size], param_attr="emb")
+    # emb(D) = emb(B) - emb(A) + emb(C)
+    target = fluid.layers.elementwise_add(
+        fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
+    emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
+    dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
+    # k=4 leaves at least one candidate after the three inputs are skipped.
+    values, pred_idx = fluid.layers.topk(input=dist, k=4)
+    return values, pred_idx
+
+
+def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
+    """ inference function """
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    emb_size = args.emb_size
+    batch_size = args.batch_size
+    start_index = args.start_index
+    last_index = args.last_index
+    model_dir = args.model_dir
+    with fluid.scope_guard(fluid.Scope()):
+        main_program = fluid.Program()
+        with fluid.program_guard(main_program):
+            values, pred = infer_network(vocab_size, emb_size)
+            for epoch in range(start_index, last_index + 1):
+                copy_program = main_program.clone()
+                model_path = model_dir + "/" + str(epoch)
+                fluid.io.load_persistables(
+                    exe, model_path, main_program=copy_program)
+                accum_num = 0
+                accum_num_sum = 0.0
+                t0 = time.time()
+                step_id = 0
+                for data in test_reader():
+                    step_id += 1
+                    b_size = len([dat[0] for dat in data])
+                    wa = np.array([dat[0] for dat in data]).astype(
+                        "int64").reshape(b_size)
+                    wb = np.array([dat[1] for dat in data]).astype(
+                        "int64").reshape(b_size)
+                    wc = np.array([dat[2] for dat in data]).astype(
+                        "int64").reshape(b_size)
+
+                    label = [dat[3] for dat in data]
+                    input_word = [dat[4] for dat in data]
+                    para = exe.run(copy_program,
+                                   feed={
+                                       "analogy_a": wa,
+                                       "analogy_b": wb,
+                                       "analogy_c": wc,
+                                       "all_label": np.arange(vocab_size)
+                                       .reshape(vocab_size).astype("int64"),
+                                   },
+                                   fetch_list=[pred.name, values],
+                                   return_numpy=False)
+                    pre = np.array(para[0])
+                    val = np.array(para[1])
+                    for ii in range(len(label)):
+                        top4 = pre[ii]
+                        accum_num_sum += 1
+                        for idx in top4:
+                            # Skip predictions that echo one of the inputs.
+                            if int(idx) in input_word[ii]:
+                                continue
+                            if int(idx) == int(label[ii][0]):
+                                accum_num += 1
+                            break
+                    if step_id % args.print_step == 0:
+                        print("step:%d %d " % (step_id, accum_num))
+
+                print("epoch:%d \t acc:%.3f " %
+                      (epoch, 1.0 * accum_num / accum_num_sum))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    start_index = args.start_index
+    last_index = args.last_index
+    test_dir = args.test_dir
+    model_dir = args.model_dir
+    batch_size = args.batch_size
+    dict_path = args.dict_path
+    use_cuda = True if args.use_cuda else False
+    print("start index: ", start_index, " last_index:", last_index)
+    vocab_size, test_reader, id2word = utils.prepare_data(
+        test_dir, dict_path, batch_size=batch_size)
+    print("vocab_size:", vocab_size)
+    infer_epoch(
+        args,
+        vocab_size,
+        test_reader=test_reader,
+        use_cuda=use_cuda,
+        i2w=id2word)
diff --git a/models/recall/word2vec/model.py b/models/recall/word2vec/model.py
index 138bfd1a7dd4a70f454fab63c50835192f4ec107..4b19513d627722310ee4fb4d39fbdb9b80b327bf 100755
--- a/models/recall/word2vec/model.py
+++ b/models/recall/word2vec/model.py
@@ -209,10 +209,10 @@ class Model(ModelBase):
         emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
         dist = fluid.layers.matmul(
             x=target, y=emb_all_label_l2, transpose_y=True)
-        values, pred_idx = fluid.layers.topk(input=dist, k=4)
+        values, pred_idx = fluid.layers.topk(input=dist, k=1)
         label = fluid.layers.expand(
             fluid.layers.unsqueeze(
-                inputs[3], axes=[1]), expand_times=[1, 4])
+                inputs[3], axes=[1]), expand_times=[1, 1])
         label_ones = fluid.layers.fill_constant_batch_size_like(
             label, shape=[-1, 1], value=1.0, dtype='float32')
         right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast(
diff --git a/models/recall/word2vec/utils.py b/models/recall/word2vec/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13a3e0ed5d6a3ccd09432a0f350192847a1dcf5
--- /dev/null
+++ b/models/recall/word2vec/utils.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import collections
+import six
+import time
+import numpy as np
+import paddle.fluid as fluid
+import paddle
+import os
+import preprocess
+import io
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def BuildWord_IdMap(dict_path):
+    word_to_id = dict()
+    id_to_word = dict()
+    with io.open(dict_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
+            id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
+    return word_to_id, id_to_word
+
+
+def prepare_data(file_dir, dict_path, batch_size):
+    w2i, i2w = BuildWord_IdMap(dict_path)
+    vocab_size = len(i2w)
+    reader = fluid.io.batch(test(file_dir, w2i), batch_size)
+    return vocab_size, reader, i2w
+
+
+def check_version(with_shuffle_batch=False):
+    """
+    Log error and exit when the installed version of paddlepaddle is
+    not satisfied.
+ """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + if with_shuffle_batch: + fluid.require_version('1.7.0') + else: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +def native_to_unicode(s): + if _is_unicode(s): + return s + try: + return _to_unicode(s) + except UnicodeDecodeError: + res = _to_unicode(s, ignore_errors=True) + return res + + +def _is_unicode(s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + +def _to_unicode(s, ignore_errors=False): + if _is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + +def strip_lines(line, vocab): + return _replace_oov(vocab, native_to_unicode(line)) + + +def _replace_oov(original_vocab, line): + """Replace out-of-vocab words with "". + This maintains compatibility with published results. + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + Returns: + a unicode string - a space-delimited sequence of words. + """ + return u" ".join([ + word if word in original_vocab else u"" for word in line.split() + ]) + + +def reader_creator(file_dir, word_to_id): + def reader(): + files = os.listdir(file_dir) + for fi in files: + with io.open( + os.path.join(file_dir, fi), "r", encoding='utf-8') as f: + for line in f: + if ':' in line: + pass + else: + line = strip_lines(line.lower(), word_to_id) + line = line.split() + yield [word_to_id[line[0]]], [word_to_id[line[1]]], [ + word_to_id[line[2]] + ], [word_to_id[line[3]]], [ + word_to_id[line[0]], word_to_id[line[1]], + word_to_id[line[2]] + ] + + return reader + + +def test(test_dir, w2i): + return reader_creator(test_dir, w2i) diff --git a/models/recall/word2vec/w2v_evaluate_reader.py b/models/recall/word2vec/w2v_evaluate_reader.py index 18571e475872b2805b40f9c7d25b471fec018ad7..40d0c66385d77e1fcdba91ca970359439319c6a3 100755 --- a/models/recall/word2vec/w2v_evaluate_reader.py +++ b/models/recall/word2vec/w2v_evaluate_reader.py @@ -76,7 +76,7 @@ class Reader(ReaderBase): def generate_sample(self, line): def reader(): if ':' in line: - pass + return features = self.strip_lines(line.lower(), self.word_to_id) features = features.split() yield [('analogy_a', [self.word_to_id[features[0]]]), diff --git a/models/recall/word2vec/w2v_reader.py b/models/recall/word2vec/w2v_reader.py index 768f380337059dbfaf793e5a8b7553703337706c..5c4e45b62c3fe11f9800ea8286264ace927c13b9 100755 --- a/models/recall/word2vec/w2v_reader.py +++ b/models/recall/word2vec/w2v_reader.py @@ -15,6 +15,7 @@ import io import numpy as np +import paddle.fluid as fluid from paddlerec.core.reader import ReaderBase from paddlerec.core.utils import envs @@ -47,6 +48,10 @@ class Reader(ReaderBase): self.with_shuffle_batch = envs.get_global_env( "hyper_parameters.with_shuffle_batch") self.random_generator = NumpyRandomInt(1, self.window_size + 1) + self.batch_size = envs.get_global_env( + "dataset.dataset_train.batch_size") + self.is_dataloader = envs.get_global_env( + "dataset.dataset_train.type") == "DataLoader" self.cs = None if not self.with_shuffle_batch: @@ -88,11 +93,46 @@ class Reader(ReaderBase): for context_id in context_word_ids: output = [('input_word', [int(target_id)]), ('true_label', [int(context_id)])] - 
+            if self.with_shuffle_batch or self.is_dataloader:
+                yield output
+            else:
                 neg_array = self.cs.searchsorted(
                     np.random.sample(self.neg_num))
                 output += [('neg_label',
                             [int(str(i)) for i in neg_array])]
-            yield output
+                yield output
 
         return reader
+
+    def batch_tensor_creator(self, sample_reader):
+        def __reader__():
+            # Accumulate input words (result[0]) and true labels (result[1])
+            # until a full batch is collected; samples beyond the last full
+            # batch are dropped.
+            result = [[], []]
+            for sample in sample_reader():
+                for i, fea in enumerate(sample):
+                    result[i].append(fea)
+                if len(result[0]) == self.batch_size:
+                    tensor_result = []
+                    for tensor in result:
+                        t = fluid.Tensor()
+                        dat = np.array(tensor, dtype='int64')
+                        if len(dat.shape) > 2:
+                            dat = dat.reshape((dat.shape[0], dat.shape[2]))
+                        elif len(dat.shape) == 1:
+                            dat = dat.reshape((-1, 1))
+                        t.set(dat, fluid.CPUPlace())
+                        tensor_result.append(t)
+                    if not self.with_shuffle_batch:
+                        # Draw one set of negative samples and share it
+                        # across every row of the batch.
+                        tt = fluid.Tensor()
+                        neg_array = self.cs.searchsorted(
+                            np.random.sample(self.neg_num))
+                        neg_array = np.tile(neg_array, self.batch_size)
+                        tt.set(
+                            neg_array.reshape((self.batch_size,
+                                               self.neg_num)),
+                            fluid.CPUPlace())
+                        tensor_result.append(tt)
+                    yield tensor_result
+                    result = [[], []]
+
+        return __reader__
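
For readers who want the evaluation logic at a glance, below is a minimal NumPy sketch of the skip-inputs analogy accuracy that infer.py computes above. The embedding table and the analogy tuples are toy stand-ins, not anything produced by this patch; only the scoring rule (emb(D) = emb(B) - emb(A) + emb(C), top-4 candidates, inputs skipped) mirrors the script.

```
import numpy as np


def analogy_accuracy(emb, analogies, topk=4):
    # L2-normalize the label table, as infer_network does with l2_normalize.
    norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    correct = 0
    for a, b, c, d in analogies:
        target = emb[b] - emb[a] + emb[c]  # emb(D) = emb(B) - emb(A) + emb(C)
        scores = norm.dot(target)          # one score per vocabulary word
        for idx in np.argsort(-scores)[:topk]:
            if idx in (a, b, c):           # skip predictions that echo the inputs
                continue
            correct += int(idx == d)       # judge the first remaining candidate
            break
    return correct / float(len(analogies))


rng = np.random.RandomState(0)
emb = rng.randn(50, 8)                     # toy 50-word, 8-dim embedding table
print(analogy_accuracy(emb, [(1, 2, 3, 4), (5, 6, 7, 8)]))
```

Taking the top 4 candidates is what makes the skip safe: even if A, B, and C all rank above D, one non-input candidate is still left to judge.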