diff --git a/core/trainers/single_infer.py b/core/trainers/single_infer.py index ee41832d6e5d2d789c37969678e85ebe2b44aaa3..d54e418c2a36d96f94ec39e53fc11c19e43d3f06 100755 --- a/core/trainers/single_infer.py +++ b/core/trainers/single_infer.py @@ -230,7 +230,7 @@ class SingleInfer(TranspileTrainer): fetch_alias = [] fetch_period = int( envs.get_global_env("runner." + self._runner_name + - ".fetch_period", 20)) + ".print_interval", 20)) metrics = model_class.get_infer_results() if metrics: fetch_vars = metrics.values() @@ -261,7 +261,7 @@ class SingleInfer(TranspileTrainer): metrics_format = [] fetch_period = int( envs.get_global_env("runner." + self._runner_name + - ".fetch_period", 20)) + ".print_interval", 20)) metrics_format.append("{}: {{}}".format("batch")) for name, var in metrics.items(): metrics_varnames.append(var.name) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index acd972ac06237f8c3ef6dec85cf0430afed18fe2..fa5331e618e7fd271fb425d4960466795b33be2b 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -220,7 +220,7 @@ class SingleTrainer(TranspileTrainer): fetch_alias = [] fetch_period = int( envs.get_global_env("runner." + self._runner_name + - ".fetch_period", 20)) + ".print_interval", 20)) metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() @@ -247,7 +247,7 @@ class SingleTrainer(TranspileTrainer): fetch_alias = [] fetch_period = int( envs.get_global_env("runner." + self._runner_name + - ".fetch_period", 20)) + ".print_interval", 20)) metrics = model_class.get_metrics() if metrics: fetch_vars = metrics.values() diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml index 57bb81d56721ff875c8a9a747ed56ac100582ec2..fd64935dd2080291fc13911befc0481604c3464a 100755 --- a/models/rank/dnn/config.yaml +++ b/models/rank/dnn/config.yaml @@ -62,7 +62,7 @@ runner: save_inference_feed_varnames: [] # feed vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference init_model_path: "" # load model path - fetch_period: 10 + print_interval: 10 - name: runner2 class: single_infer # num of epochs diff --git a/models/recall/fasttext/__init__.py b/models/recall/fasttext/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/models/recall/fasttext/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/recall/fasttext/config.yaml b/models/recall/fasttext/config.yaml new file mode 100755 index 0000000000000000000000000000000000000000..24c10bcbe8a49de1eaa7fdedab309f8d865e345a --- /dev/null +++ b/models/recall/fasttext/config.yaml @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +workspace: "paddlerec.models.recall.fasttext" + +# list of dataset +dataset: +- name: dataset_train # name of dataset to distinguish different datasets + batch_size: 100 + type: DataLoader # or QueueDataset + data_path: "{workspace}/data/train" + word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt" + word_ngrams_path: "{workspace}/data/dict/word_ngrams_id" # n-gram id file produced by preprocess.py; adjust to your own output path + data_converter: "{workspace}/reader.py" +- name: dataset_infer # name + batch_size: 50 + type: DataLoader # or QueueDataset + data_path: "{workspace}/data/test" + word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt" + data_converter: "{workspace}/evaluate_reader.py" + +hyper_parameters: + optimizer: + learning_rate: 1.0 + decay_steps: 100000 + decay_rate: 0.999 + class: sgd + strategy: async + sparse_feature_number: 224093 + sparse_feature_dim: 300 + with_shuffle_batch: False + neg_num: 5 + window_size: 5 + min_n: 3 + max_n: 5 + +# select runner by name +mode: runner1 +# config of each runner. +# runner is a kind of paddle training class, which wraps the train/infer process. +runner: +- name: runner1 + class: single_train + # num of epochs + epochs: 2 + # device to run training or infer + device: cpu + save_checkpoint_interval: 1 # save model interval of epochs + save_inference_interval: 1 # save inference + save_checkpoint_path: "increment" # save checkpoint path + save_inference_path: "inference" # save inference path + save_inference_feed_varnames: [] # feed vars of save inference + save_inference_fetch_varnames: [] # fetch vars of save inference + init_model_path: "" # load model path + print_interval: 10 +- name: runner2 + class: single_infer + # num of epochs + epochs: 1 + # device to run training or infer + device: cpu + init_model_path: "increment/0" # load model path + +# runner will run all the phase in each epoch +phase: +- name: phase1 + model: "{workspace}/model.py" # user-defined model + dataset_name: dataset_train # select dataset by name + thread_num: 1 +#- name: phase2 +# model: "{workspace}/model.py" # user-defined model +# dataset_name: dataset_infer # select dataset by name +# thread_num: 1 diff --git a/models/recall/fasttext/evaluate_reader.py b/models/recall/fasttext/evaluate_reader.py new file mode 100755 index 0000000000000000000000000000000000000000..d4357ab6c66d5f1daf8738d2e6a645ee270e86d5 --- /dev/null +++ b/models/recall/fasttext/evaluate_reader.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import io + +import six + +from paddlerec.core.reader import Reader +from paddlerec.core.utils import envs + + +class TrainReader(Reader): + def init(self): + dict_path = envs.get_global_env("dataset.dataset_infer.word_id_dict_path") + self.min_n = envs.get_global_env("hyper_parameters.min_n") + self.max_n = envs.get_global_env("hyper_parameters.max_n") + self.word_to_id = dict() + self.id_to_word = dict() + with io.open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + self.word_to_id[line.split(' ')[0]] = int(line.split(' ')[1]) + self.id_to_word[int(line.split(' ')[1])] = line.split(' ')[0] + self.dict_size = len(self.word_to_id) + + def computeSubwords(self, word): + ngrams = set() + for i in range(len(word) - self.min_n + 1): + for j in range(self.min_n, self.max_n + 1): + end = min(len(word), i + j) + ngrams.add("".join(word[i:end])) + return list(ngrams) + + def native_to_unicode(self, s): + if self._is_unicode(s): + return s + try: + return self._to_unicode(s) + except UnicodeDecodeError: + res = self._to_unicode(s, ignore_errors=True) + return res + + def _is_unicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + def _to_unicode(self, s, ignore_errors=False): + if self._is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + def strip_lines(self, line, vocab): + return self._replace_oov(vocab, self.native_to_unicode(line)) + + def _replace_oov(self, original_vocab, line): + """Replace out-of-vocab words with "". + This maintains compatibility with published results. + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + Returns: + a unicode string - a space-delimited sequence of words. + """ + return u" ".join([ + "<" + word + ">" if "<" + word + ">" in original_vocab else u"" + for word in line.split() + ]) + + def generate_sample(self, line): + def reader(): + features = self.strip_lines(line.lower(), self.word_to_id) + features = features.split() + print(features) + inputs = [] + for item in features: + if item == "": + inputs.append([self.word_to_id[item]]) + else: + ngrams = self.computeSubwords(item) + res = [] + res.append(self.word_to_id[item]) + for _ in ngrams: + res.append(self.word_to_id[_]) + inputs.append(res) + print(inputs) + yield [('analogy_a', inputs[0]), + ('analogy_b', inputs[1]), + ('analogy_c', inputs[2]), + ('analogy_d', inputs[3][0:1])] + + return reader diff --git a/models/recall/fasttext/model.py b/models/recall/fasttext/model.py new file mode 100755 index 0000000000000000000000000000000000000000..6c9fe3c1ca9cd4b13ed9cad3332ea4a4b7523fe8 --- /dev/null +++ b/models/recall/fasttext/model.py @@ -0,0 +1,201 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import paddle.fluid as fluid + +from paddlerec.core.utils import envs +from paddlerec.core.model import Model as ModelBase + + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def _init_hyper_parameters(self): + self.is_distributed = True if envs.get_trainer() == "CtrTrainer" else False + self.sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number") + self.sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim") + self.neg_num = envs.get_global_env("hyper_parameters.neg_num") + self.with_shuffle_batch = envs.get_global_env("hyper_parameters.with_shuffle_batch") + self.learning_rate = envs.get_global_env("hyper_parameters.optimizer.learning_rate") + self.decay_steps = envs.get_global_env("hyper_parameters.optimizer.decay_steps") + self.decay_rate = envs.get_global_env("hyper_parameters.optimizer.decay_rate") + + + def input_data(self, is_infer=False, **kwargs): + if is_infer: + analogy_a = fluid.data( + name="analogy_a", shape=[None, 1], lod_level=1, dtype='int64') + analogy_b = fluid.data( + name="analogy_b", shape=[None, 1], lod_level=1, dtype='int64') + analogy_c = fluid.data( + name="analogy_c", shape=[None, 1], lod_level=1, dtype='int64') + analogy_d = fluid.data( + name="analogy_d", shape=[None, 1], dtype='int64') + return [analogy_a, analogy_b, analogy_c, analogy_d] + + input_word = fluid.data( + name="input_word", shape=[None, 1], lod_level=1, dtype='int64') + true_word = fluid.data( + name='true_label', shape=[None, 1], lod_level=1, dtype='int64') + if self.with_shuffle_batch: + return [input_word, true_word] + + neg_word = fluid.data( + name="neg_label", shape=[None, self.neg_num], dtype='int64') + return [input_word, true_word, neg_word] + + def net(self, inputs, is_infer=False): + if is_infer: + self.infer_net(inputs) + return + + def embedding_layer(input, + table_name, + initializer_instance=None, + sequence_pool=False): + emb = fluid.embedding( + input=input, + is_sparse=True, + is_distributed=self.is_distributed, + size=[self.sparse_feature_number, self.sparse_feature_dim], + param_attr=fluid.ParamAttr( + name=table_name, initializer=initializer_instance), ) + if sequence_pool: + emb = fluid.layers.sequence_pool(input=emb, pool_type='average') + return emb + + init_width = 1.0 / self.sparse_feature_dim + emb_initializer = fluid.initializer.Uniform(-init_width, init_width) + emb_w_initializer = fluid.initializer.Constant(value=0.0) + + input_emb = embedding_layer(inputs[0], "emb", emb_initializer, True) + input_emb = fluid.layers.squeeze(input=input_emb, axes=[1]) + true_emb_w = embedding_layer(inputs[1], "emb_w", emb_w_initializer, True) + true_emb_w = fluid.layers.squeeze(input=true_emb_w, axes=[1]) + + + if self.with_shuffle_batch: + neg_emb_w_list = [] + for i in range(self.neg_num): + neg_emb_w_list.append( + fluid.contrib.layers.shuffle_batch( + true_emb_w)) # shuffle true_word + neg_emb_w_concat = fluid.layers.concat(neg_emb_w_list, axis=0) + neg_emb_w = fluid.layers.reshape( + neg_emb_w_concat, shape=[-1, self.neg_num, self.sparse_feature_dim]) + else: + neg_emb_w = embedding_layer(inputs[2], "emb_w", emb_w_initializer) + true_logits = fluid.layers.reduce_sum( + fluid.layers.elementwise_mul(input_emb, true_emb_w), + dim=1, + keep_dim=True) + + input_emb_re = fluid.layers.reshape( + input_emb, shape=[-1, 1, self.sparse_feature_dim]) + neg_matmul = fluid.layers.matmul( + input_emb_re, neg_emb_w, transpose_y=True) + neg_logits = fluid.layers.reshape( + 
neg_matmul, shape=[-1, 1]) + + logits = fluid.layers.concat([true_logits, neg_logits], axis=0) + label_ones = fluid.layers.fill_constant( + shape=[fluid.layers.shape(true_logits)[0], 1], + value=1.0, + dtype='float32') + label_zeros = fluid.layers.fill_constant( + shape=[fluid.layers.shape(neg_logits)[0], 1], + value=0.0, + dtype='float32') + label = fluid.layers.concat([label_ones, label_zeros], axis=0) + + loss = fluid.layers.log_loss(fluid.layers.sigmoid(logits), label) + avg_cost = fluid.layers.reduce_sum(loss) + self._cost = avg_cost + self._metrics["LOSS"] = avg_cost + + + def optimizer(self): + optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.exponential_decay( + learning_rate=self.learning_rate, + decay_steps=self.decay_steps, + decay_rate=self.decay_rate, + staircase=True)) + return optimizer + + def infer_net(self, inputs): + def embedding_layer(input, table_name, initializer_instance=None, sequence_pool=False): + emb = fluid.embedding( + input=input, + size=[self.sparse_feature_number, self.sparse_feature_dim], + param_attr=table_name) + if sequence_pool: + emb = fluid.layers.sequence_pool(input=emb, pool_type='average') + return emb + + all_label = np.arange(self.sparse_feature_number).reshape( + self.sparse_feature_number).astype('int32') + self.all_label = fluid.layers.cast( + x=fluid.layers.assign(all_label), dtype='int64') + emb_all_label = embedding_layer(self.all_label, "emb") + fluid.layers.Print(inputs[0]) + fluid.layers.Print(inputs[1]) + fluid.layers.Print(inputs[2]) + fluid.layers.Print(inputs[3]) + emb_a = embedding_layer(inputs[0], "emb", sequence_pool=True) + emb_b = embedding_layer(inputs[1], "emb", sequence_pool=True) + emb_c = embedding_layer(inputs[2], "emb", sequence_pool=True) + + target = fluid.layers.elementwise_add( + fluid.layers.elementwise_sub(emb_b, emb_a), emb_c) + + emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1) + dist = fluid.layers.matmul( + x=target, y=emb_all_label_l2, transpose_y=True) + values, pred_idx = fluid.layers.topk(input=dist, k=4) + label = fluid.layers.expand( + inputs[3], + expand_times=[1, 4]) + label_ones = fluid.layers.fill_constant_batch_size_like( + label, shape=[-1, 1], value=1.0, dtype='float32') + right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast( + fluid.layers.equal(pred_idx, label), dtype='float32')) + total_cnt = fluid.layers.reduce_sum(label_ones) + + # global_right_cnt = fluid.layers.create_global_var( + # name="global_right_cnt", + # persistable=True, + # dtype='float32', + # shape=[1], + # value=0) + # global_total_cnt = fluid.layers.create_global_var( + # name="global_total_cnt", + # persistable=True, + # dtype='float32', + # shape=[1], + # value=0) + # global_right_cnt.stop_gradient = True + # global_total_cnt.stop_gradient = True + + # tmp1 = fluid.layers.elementwise_add(right_cnt, global_right_cnt) + # fluid.layers.assign(tmp1, global_right_cnt) + # tmp2 = fluid.layers.elementwise_add(total_cnt, global_total_cnt) + # fluid.layers.assign(tmp2, global_total_cnt) + + # acc = fluid.layers.elementwise_div( + # global_right_cnt, global_total_cnt, name="total_acc") + acc = fluid.layers.elementwise_div(right_cnt, total_cnt, name="acc") + self._infer_results['acc'] = acc diff --git a/models/recall/fasttext/preprocess.py b/models/recall/fasttext/preprocess.py new file mode 100755 index 0000000000000000000000000000000000000000..43d809718a9c18208bf47f1d2c0a871194503cd7 --- /dev/null +++ b/models/recall/fasttext/preprocess.py @@ -0,0 +1,306 @@ +# -*- coding: utf-8 -* +# Copyright 
(c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import math +import os +import random +import re +import six + +import argparse + +prog = re.compile("[^a-z ]", flags=0) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Paddle Fluid word2 vector preprocess") + parser.add_argument( + '--build_dict_corpus_dir', type=str, help="The dir of corpus") + parser.add_argument( + '--input_corpus_dir', type=str, help="The dir of input corpus") + parser.add_argument( + '--output_corpus_dir', type=str, help="The dir of output corpus") + parser.add_argument( + '--dict_path', + type=str, + default='./dict', + help="The path of dictionary ") + parser.add_argument( + '--min_count', + type=int, + default=5, + help="If the word count is less then min_count, it will be removed from dict" + ) + parser.add_argument( + '--min_n', + type=int, + default=3, + help="min_n of ngrams" + ) + parser.add_argument( + '--max_n', + type=int, + default=5, + help="max_n of ngrams" + ) + parser.add_argument( + '--file_nums', + type=int, + default=1024, + help="re-split input corpus file nums") + parser.add_argument( + '--downsample', + type=float, + default=0.001, + help="filter word by downsample") + parser.add_argument( + '--filter_corpus', + action='store_true', + default=False, + help='Filter corpus') + parser.add_argument( + '--build_dict', + action='store_true', + default=False, + help='Build dict from corpus') + parser.add_argument( + '--data_resplit', + action='store_true', + default=False, + help='re-split input corpus files') + return parser.parse_args() + + +def text_strip(text): + # English Preprocess Rule + return prog.sub("", text.lower()) + + +# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py +# Unicode utility functions that work with Python 2 and 3 +def native_to_unicode(s): + if _is_unicode(s): + return s + try: + return _to_unicode(s) + except UnicodeDecodeError: + res = _to_unicode(s, ignore_errors=True) + return res + + +def _is_unicode(s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + +def _to_unicode(s, ignore_errors=False): + if _is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + +def filter_corpus(args): + """ + filter corpus and convert id. 
+ """ + word_count = dict() + word_to_id_ = dict() + word_all_count = 0 + id_counts = [] + word_id = 0 + # read dict + with io.open(args.dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, count = line.split()[0], int(line.split()[1]) + word_count[word] = count + word_to_id_[word] = word_id + word_id += 1 + id_counts.append(count) + word_all_count += count + + word_ngrams = dict() + with io.open("word_ngrams", 'r', encoding='utf-8') as f: + for line in f: + word, ngrams = line.rstrip().split(':') + ngrams = ngrams.split() + ngrams = [str(word_to_id_[_]) for _ in ngrams] + word_ngrams[word_to_id_[word]] = ' '.join(ngrams) + + with io.open("word_ngrams_id", 'w+', encoding='utf-8') as fid: + for k, v in word_ngrams.items(): + fid.write(u'{} {}\n'.format(k, v)) + + # write word2id file + print("write word2id file to : " + args.dict_path + "_word_to_id_") + with io.open( + args.dict_path + "_word_to_id_", 'w+', encoding='utf-8') as fid: + for k, v in word_to_id_.items(): + fid.write(k + " " + str(v) + '\n') + # filter corpus and convert id + if not os.path.exists(args.output_corpus_dir): + os.makedirs(args.output_corpus_dir) + for file in os.listdir(args.input_corpus_dir): + with io.open(args.output_corpus_dir + '/convert_' + file + '.csv', + "w") as wf: + with io.open( + args.input_corpus_dir + '/' + file, + encoding='utf-8') as rf: + print(args.input_corpus_dir + '/' + file) + for line in rf: + signal = False + line = text_strip(line) + words = line.split() + write_line = "" + for item in words: + if item in word_count: + idx = word_to_id_[item] + else: + idx = word_to_id_[native_to_unicode('')] + count_w = id_counts[idx] + corpus_size = word_all_count + keep_prob = ( + math.sqrt(count_w / + (args.downsample * corpus_size)) + 1 + ) * (args.downsample * corpus_size) / count_w + r_value = random.random() + if r_value > keep_prob: + continue + write_line += str(idx) + write_line += "," + signal = True + if signal: + write_line = write_line[:-1] + "\n" + wf.write(_to_unicode(write_line)) + + +def computeSubwords(word, min_n, max_n): + ngrams = set() + for i in range(len(word) - min_n + 1): + for j in range(min_n, max_n + 1): + end = min(len(word), i + j) + ngrams.add("".join(word[i:end])) + return list(ngrams) + +def build_dict(args): + """ + proprocess the data, generate dictionary and save into dict_path. + :param corpus_dir: the input data dir. + :param dict_path: the generated dict path. 
the data in dict is "word count" + :param min_count: + :return: + """ + # word to count + + word_count = dict() + + for file in os.listdir(args.build_dict_corpus_dir): + with io.open( + args.build_dict_corpus_dir + "/" + file, + encoding='utf-8') as f: + print("build dict : ", args.build_dict_corpus_dir + "/" + file) + for line in f: + line = text_strip(line) + words = line.split() + for item in words: + item = '<' + item + '>' + if item in word_count: + word_count[item] = word_count[item] + 1 + else: + word_count[item] = 1 + + item_to_remove = [] + for item in word_count: + if word_count[item] <= args.min_count: + item_to_remove.append(item) + + unk_sum = 0 + for item in item_to_remove: + unk_sum += word_count[item] + del word_count[item] + # sort by count + word_count[native_to_unicode('')] = unk_sum + + word_ngrams = dict() + ngrams_count = dict() + for item in word_count: + ngrams = computeSubwords(item, args.min_n, args.max_n) + word_ngrams[item] = ngrams + for sub_word in ngrams: + if sub_word not in ngrams_count: + ngrams_count[sub_word] = 1 + else: + ngrams_count[sub_word] = ngrams_count[sub_word] + 1 + ngrams_count = sorted( + ngrams_count.items(), key=lambda ngrams_count: -ngrams_count[1]) + + word_count = sorted( + word_count.items(), key=lambda word_count: -word_count[1]) + with io.open(args.dict_path, 'w+', encoding='utf-8') as f: + for k, v in word_count: + f.write(k + " " + str(v) + '\n') + for k, v in ngrams_count: + f.write(k + " " + str(v) + '\n') + + with io.open("word_ngrams", 'w+', encoding='utf-8') as f: + for key in word_ngrams: + f.write(key + ":") + f.write(" ".join(word_ngrams[key])) + f.write(u'\n') + +def data_split(args): + raw_data_dir = args.input_corpus_dir + new_data_dir = args.output_corpus_dir + if not os.path.exists(new_data_dir): + os.mkdir(new_data_dir) + files = os.listdir(raw_data_dir) + print(files) + index = 0 + contents = [] + for file_ in files: + with open(os.path.join(raw_data_dir, file_), 'r') as f: + contents.extend(f.readlines()) + + num = int(args.file_nums) + lines_per_file = len(contents) // num # integer division keeps the slice indices below ints on Python 3 + print("contents: ", str(len(contents))) + print("lines_per_file: ", str(lines_per_file)) + + for i in range(1, num + 1): + with open(os.path.join(new_data_dir, "part_" + str(i)), 'w') as fout: + data = contents[(i - 1) * lines_per_file:min(i * lines_per_file, + len(contents))] + for line in data: + fout.write(line) + + +if __name__ == "__main__": + args = parse_args() + if args.build_dict: + build_dict(args) + elif args.filter_corpus: + filter_corpus(args) + elif args.data_resplit: + data_split(args) + else: + print( + "error command line, please choose --build_dict, --filter_corpus or --data_resplit") diff --git a/models/recall/fasttext/reader.py b/models/recall/fasttext/reader.py new file mode 100755 index 0000000000000000000000000000000000000000..a804ef902d9bc17d1352db2fb44bd664b786dbaf --- /dev/null +++ b/models/recall/fasttext/reader.py @@ -0,0 +1,106 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import io + +import numpy as np + +from paddlerec.core.reader import Reader +from paddlerec.core.utils import envs + + +class NumpyRandomInt(object): + def __init__(self, a, b, buf_size=1000): + self.idx = 0 + self.buffer = np.random.random_integers(a, b, buf_size) + self.a = a + self.b = b + + def __call__(self): + if self.idx == len(self.buffer): + self.buffer = np.random.random_integers(self.a, self.b, + len(self.buffer)) + self.idx = 0 + + result = self.buffer[self.idx] + self.idx += 1 + return result + + +class TrainReader(Reader): + def init(self): + dict_path = envs.get_global_env("dataset.dataset_train.word_count_dict_path") + word_ngrams_path = envs.get_global_env("dataset.dataset_train.word_ngrams_path") + self.window_size = envs.get_global_env("hyper_parameters.window_size") + self.neg_num = envs.get_global_env("hyper_parameters.neg_num") + self.with_shuffle_batch = envs.get_global_env( + "hyper_parameters.with_shuffle_batch") + self.random_generator = NumpyRandomInt(1, self.window_size + 1) + + self.word_ngrams = dict() + with io.open(word_ngrams_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.rstrip().split() + self.word_ngrams[str(line[0])] = list(map(int, line[1:])) # list() so the ids can be reused on Python 3 + + self.cs = None + if not self.with_shuffle_batch: + id_counts = [] + word_all_count = 0 + with io.open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, count = line.split()[0], int(line.split()[1]) + id_counts.append(count) + word_all_count += count + id_frequencys = [ + float(count) / word_all_count for count in id_counts + ] + np_power = np.power(np.array(id_frequencys), 0.75) + id_frequencys_pow = np_power / np_power.sum() + self.cs = np.array(id_frequencys_pow).cumsum() + + def get_context_words(self, words, idx): + """ + Get the context word list of target word. + words: the words of the current line + idx: input word index + window_size: window size + """ + target_window = self.random_generator() + start_point = idx - target_window # if (idx - target_window) > 0 else 0 + if start_point < 0: + start_point = 0 + end_point = idx + target_window + targets = words[start_point:idx] + words[idx + 1:end_point + 1] + return targets + + def generate_sample(self, line): + def reader(): + word_ids = [w for w in line.split()] + for idx, target_id in enumerate(word_ids): + input_word = [int(target_id)] + if target_id in self.word_ngrams: + input_word += self.word_ngrams[target_id] + context_word_ids = self.get_context_words(word_ids, idx) + for context_id in context_word_ids: + output = [('input_word', input_word), + ('true_label', [int(context_id)])] + if not self.with_shuffle_batch: + neg_array = self.cs.searchsorted( + np.random.sample(self.neg_num)) + output += [('neg_label', + [int(str(i)) for i in neg_array])] + yield output + + return reader
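
Note on the n-gram logic in this patch: computeSubwords() in preprocess.py and evaluate_reader.py collects every character n-gram of length min_n..max_n from a token that build_dict() has already wrapped in '<' and '>'. A minimal, self-contained sketch, assuming the defaults min_n=3 and max_n=5 from config.yaml (the function is renamed and sorted() is added here only to make the example output deterministic):

    def compute_subwords(word, min_n, max_n):
        # Collect every character n-gram with length in [min_n, max_n],
        # mirroring computeSubwords() in preprocess.py.
        ngrams = set()
        for i in range(len(word) - min_n + 1):
            for j in range(min_n, max_n + 1):
                end = min(len(word), i + j)
                ngrams.add(word[i:end])
        return sorted(ngrams)

    # build_dict() wraps each token as '<' + item + '>' before extracting n-grams.
    print(compute_subwords("<where>", 3, 5))
    # ['<wh', '<whe', '<wher', 'ere', 'ere>', 'her', 'here', 'here>',
    #  're>', 'whe', 'wher', 'where']

reader.py draws negative samples by inverting the cumulative distribution of the unigram frequencies raised to the 0.75 power with numpy's searchsorted. A sketch with made-up counts (the real counts come from word_count_dict.txt):

    import numpy as np

    # Hypothetical counts; reader.py reads the real ones from word_count_dict.txt.
    id_counts = np.array([100.0, 10.0, 1.0])
    id_frequencys = id_counts / id_counts.sum()
    np_power = np.power(id_frequencys, 0.75)
    cs = (np_power / np_power.sum()).cumsum()   # CDF over word ids

    neg_num = 5
    # Uniform samples inverted through the CDF: ids are drawn with probability
    # proportional to count ** 0.75, as in TrainReader.generate_sample.
    neg_array = cs.searchsorted(np.random.sample(neg_num))
    print(neg_array)   # e.g. [0 0 1 0 2]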