From 795f74e3c7a35603e2af3198114bd9f9f0e74685 Mon Sep 17 00:00:00 2001
From: malin10
Date: Fri, 8 May 2020 14:35:57 +0800
Subject: [PATCH] add w2v

---
 models/recall/w2v_evaluate_reader.py          |  80 +++++++
 models/recall/w2v_reader.py                   |  90 ++++++++
 models/recall/word2vec/config.yaml            |  59 +++++
 .../word2vec/data/dict/word_count_dict.txt    |  85 ++++++++
 .../word2vec/data/dict/word_id_dict.txt       |  85 ++++++++
 models/recall/word2vec/data/test/sample.txt   | 200 +++++++++++++++++
 .../word2vec/data/train/convert_sample.txt    | 195 +++++++++++++++++
 models/recall/word2vec/model.py               | 202 ++++++++++++++++++
 8 files changed, 996 insertions(+)
 create mode 100755 models/recall/w2v_evaluate_reader.py
 create mode 100755 models/recall/w2v_reader.py
 create mode 100644 models/recall/word2vec/config.yaml
 create mode 100644 models/recall/word2vec/data/dict/word_count_dict.txt
 create mode 100644 models/recall/word2vec/data/dict/word_id_dict.txt
 create mode 100644 models/recall/word2vec/data/test/sample.txt
 create mode 100644 models/recall/word2vec/data/train/convert_sample.txt
 create mode 100644 models/recall/word2vec/model.py

diff --git a/models/recall/w2v_evaluate_reader.py b/models/recall/w2v_evaluate_reader.py
new file mode 100755
index 00000000..df6de931
--- /dev/null
+++ b/models/recall/w2v_evaluate_reader.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import io
+import six
+from fleetrec.core.reader import Reader
+from fleetrec.core.utils import envs
+
+
+class EvaluateReader(Reader):
+    def init(self):
+        dict_path = envs.get_global_env("word_id_dict_path", None, "evaluate.reader")
+        self.word_to_id = dict()
+        self.id_to_word = dict()
+        with io.open(dict_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                word, word_id = line.split(' ')[0], int(line.split(' ')[1])
+                self.word_to_id[word] = word_id
+                self.id_to_word[word_id] = word
+        self.dict_size = len(self.word_to_id)
+
+    def native_to_unicode(self, s):
+        if self._is_unicode(s):
+            return s
+        try:
+            return self._to_unicode(s)
+        except UnicodeDecodeError:
+            res = self._to_unicode(s, ignore_errors=True)
+            return res
+
+    def _is_unicode(self, s):
+        if six.PY2:
+            if isinstance(s, unicode):
+                return True
+        else:
+            if isinstance(s, str):
+                return True
+        return False
+
+    def _to_unicode(self, s, ignore_errors=False):
+        if self._is_unicode(s):
+            return s
+        error_mode = "ignore" if ignore_errors else "strict"
+        return s.decode("utf-8", errors=error_mode)
+
+    def strip_lines(self, line, vocab):
+        return self._replace_oov(vocab, self.native_to_unicode(line))
+
+    def _replace_oov(self, original_vocab, line):
+        """Replace out-of-vocab words with "<UNK>".
+        This maintains compatibility with published results.
+        Args:
+            original_vocab: a set of strings (the standard vocabulary for the dataset)
+            line: a unicode string, a space-delimited sequence of words
+        Returns:
+            a unicode string, a space-delimited sequence of words
+        """
+        return u" ".join([
+            word if word in original_vocab else u"<UNK>" for word in line.split()
+        ])
+
+    def generate_sample(self, line):
+        def reader():
+            features = self.strip_lines(line.lower(), self.word_to_id)
+            features = features.split()
+            yield [('analogy_a', [self.word_to_id[features[0]]]),
+                   ('analogy_b', [self.word_to_id[features[1]]]),
+                   ('analogy_c', [self.word_to_id[features[2]]]),
+                   ('analogy_d', [self.word_to_id[features[3]]])]
+
+        return reader
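
Note on the evaluate reader above: each test line is one analogy question of four
words ("Athens Greece Baghdad Iraq"), lower-cased, OOV-filtered through
_replace_oov, and mapped to ids. A minimal standalone sketch of that flow with a
made-up toy vocabulary (nothing below is part of the patch):

    word_to_id = {u"<UNK>": 0, u"athens": 1, u"greece": 2, u"baghdad": 3, u"iraq": 4}

    def replace_oov(vocab, line):
        # Out-of-vocabulary words collapse to <UNK>, mirroring _replace_oov.
        return u" ".join(w if w in vocab else u"<UNK>" for w in line.split())

    line = u"Athens Greece Baghdad Iraq".lower()
    ids = [word_to_id[w] for w in replace_oov(word_to_id, line).split()]
    assert ids == [1, 2, 3, 4]  # feeds analogy_a .. analogy_d
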
+ """ + return u" ".join([ + word if word in original_vocab else u"" for word in line.split() + ]) + + def generate_sample(self, line): + def reader(): + features = self.strip_lines(line.lower(), self.word_to_id) + features = features.split() + yield [('analogy_a', [self.word_to_id[features[0]]]), ('analogy_b', [self.word_to_id[features[1]]]), ('analogy_c', [self.word_to_id[features[2]]]), ('analogy_d', [self.word_to_id[features[3]]])] + return reader diff --git a/models/recall/w2v_reader.py b/models/recall/w2v_reader.py new file mode 100755 index 00000000..4857a6b0 --- /dev/null +++ b/models/recall/w2v_reader.py @@ -0,0 +1,90 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import io +from fleetrec.core.reader import Reader +from fleetrec.core.utils import envs + + +class NumpyRandomInt(object): + def __init__(self, a, b, buf_size=1000): + self.idx = 0 + self.buffer = np.random.random_integers(a, b, buf_size) + self.a = a + self.b = b + + def __call__(self): + if self.idx == len(self.buffer): + self.buffer = np.random.random_integers(self.a, self.b, + len(self.buffer)) + self.idx = 0 + + result = self.buffer[self.idx] + self.idx += 1 + return result + + +class TrainReader(Reader): + def init(self): + dict_path = envs.get_global_env("word_count_dict_path", None, "train.reader") + self.window_size = envs.get_global_env("hyper_parameters.window_size", None, "train.model") + self.neg_num = envs.get_global_env("hyper_parameters.neg_num", None, "train.model") + self.with_shuffle_batch = envs.get_global_env("hyper_parameters.with_shuffle_batch", None, "train.model") + self.random_generator = NumpyRandomInt(1, self.window_size + 1) + + self.cs = None + if not self.with_shuffle_batch: + id_counts = [] + word_all_count = 0 + with io.open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, count = line.split()[0], int(line.split()[1]) + id_counts.append(count) + word_all_count += count + id_frequencys = [ + float(count) / word_all_count for count in id_counts + ] + np_power = np.power(np.array(id_frequencys), 0.75) + id_frequencys_pow = np_power / np_power.sum() + self.cs = np.array(id_frequencys_pow).cumsum() + + def get_context_words(self, words, idx): + """ + Get the context word list of target word. 
diff --git a/models/recall/word2vec/config.yaml b/models/recall/word2vec/config.yaml
new file mode 100644
index 00000000..af017b81
--- /dev/null
+++ b/models/recall/word2vec/config.yaml
@@ -0,0 +1,59 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+evaluate:
+  workspace: "fleetrec.models.recall.word2vec"
+  reader:
+    batch_size: 50
+    class: "{workspace}/../w2v_evaluate_reader.py"
+    test_data_path: "{workspace}/data/test"
+    word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
+
+train:
+  trainer:
+    # for cluster training
+    strategy: "async"
+
+  epochs: 2
+  workspace: "fleetrec.models.recall.word2vec"
+
+  reader:
+    batch_size: 100
+    class: "{workspace}/../w2v_reader.py"
+    train_data_path: "{workspace}/data/train"
+    test_data_path: "{workspace}/data/test"
+    word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
+    word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
+
+  model:
+    models: "{workspace}/model.py"
+    hyper_parameters:
+      sparse_feature_number: 354051
+      sparse_feature_dim: 300
+      with_shuffle_batch: False
+      neg_num: 5
+      window_size: 5
+      learning_rate: 1.0
+      decay_steps: 100000
+      decay_rate: 0.999
+      optimizer: sgd
+
+  save:
+    increment:
+      dirname: "increment"
+      epoch_interval: 1
+      save_last: True
+    inference:
+      dirname: "inference"
+      epoch_interval: 1
+      save_last: True
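
Note on the hyper-parameters above: optimizer sgd with learning_rate 1.0,
decay_steps 100000, and decay_rate 0.999 gives a staircase exponential decay. A
quick sketch of the schedule those three values produce (function name is
illustrative, not part of the patch):

    def staircase_lr(step, base_lr=1.0, decay_steps=100000, decay_rate=0.999):
        # The rate drops by a factor of 0.999 once every 100000 mini-batches.
        return base_lr * decay_rate ** (step // decay_steps)

    print(staircase_lr(0))       # 1.0
    print(staircase_lr(100000))  # 0.999
    print(staircase_lr(200000))  # 0.998001
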
diff --git a/models/recall/word2vec/data/dict/word_count_dict.txt b/models/recall/word2vec/data/dict/word_count_dict.txt
new file mode 100644
index 00000000..4add75bc
--- /dev/null
+++ b/models/recall/word2vec/data/dict/word_count_dict.txt
@@ -0,0 +1,85 @@
+<UNK> 2541
+the 256
+to 135
+of 122
+a 106
+in 97
+and 94
+that 54
+for 49
+is 47
+on 44
+s 43
+at 37
+said 34
+be 31
+with 27
+will 26
+are 25
+have 24
+was 23
+it 22
+more 20
+who 20
+an 19
+as 19
+by 18
+his 18
+from 18
+they 17
+not 16
+their 16
+has 15
+there 15
+this 15
+but 15
+we 13
+he 13
+been 12
+out 12
+new 11
+would 11
+than 11
+were 11
+year 10
+or 10
+us 10
+had 9
+first 9
+all 9
+two 9
+after 8
+them 8
+t 8
+most 8
+last 8
+some 8
+so 8
+i 8
+even 7
+when 7
+according 7
+its 7
+during 7
+per 7
+because 7
+up 7
+she 7
+home 7
+about 7
+mr 6
+do 6
+if 6
+just 6
+no 6
+time 6
+team 6
+may 6
+years 6
+city 6
+only 6
+world 6
+you 6
+including 6
+day 6
+cent 6
diff --git a/models/recall/word2vec/data/dict/word_id_dict.txt b/models/recall/word2vec/data/dict/word_id_dict.txt
new file mode 100644
index 00000000..f8bd06de
--- /dev/null
+++ b/models/recall/word2vec/data/dict/word_id_dict.txt
@@ -0,0 +1,85 @@
+and 6
+all 48
+because 64
+just 72
+per 63
+when 59
+is 9
+year 43
+some 55
+it 20
+an 23
+as 24
+including 82
+at 12
+have 18
+in 5
+home 67
+its 61
+<UNK> 0
+even 58
+city 78
+said 13
+from 27
+for 8
+their 30
+there 32
+had 46
+two 49
+been 37
+than 41
+up 65
+to 2
+only 79
+time 74
+new 39
+you 81
+has 31
+was 19
+day 83
+more 21
+be 14
+we 35
+his 26
+may 76
+do 70
+that 7
+mr 69
+she 66
+team 75
+who 22
+but 34
+if 71
+most 53
+cent 84
+them 51
+they 28
+not 29
+during 62
+years 77
+with 15
+by 25
+after 50
+he 36
+a 4
+on 10
+about 68
+last 54
+would 40
+world 80
+this 33
+of 3
+no 73
+according 60
+us 45
+will 16
+i 57
+s 11
+so 56
+t 52
+were 42
+the 1
+first 47
+out 38
+or 44
+are 17
diff --git a/models/recall/word2vec/data/test/sample.txt b/models/recall/word2vec/data/test/sample.txt
new file mode 100644
index 00000000..0b211c91
--- /dev/null
+++ b/models/recall/word2vec/data/test/sample.txt
@@ -0,0 +1,200 @@
+Athens Greece Baghdad Iraq
+Athens Greece Bangkok Thailand
+Athens Greece Beijing China
+Athens Greece Berlin Germany
+Athens Greece Bern Switzerland
+Athens Greece Cairo Egypt
+Athens Greece Canberra Australia
+Athens Greece Hanoi Vietnam
+Athens Greece Havana Cuba
+Athens Greece Helsinki Finland
+Athens Greece Islamabad Pakistan
+Athens Greece Kabul Afghanistan
+Athens Greece London England
+Athens Greece Madrid Spain
+Athens Greece Moscow Russia
+Athens Greece Oslo Norway
+Athens Greece Ottawa Canada
+Athens Greece Paris France
+Athens Greece Rome Italy
+Athens Greece Stockholm Sweden
+Athens Greece Tehran Iran
+Athens Greece Tokyo Japan
+Baghdad Iraq Bangkok Thailand
+Baghdad Iraq Beijing China
+Baghdad Iraq Berlin Germany
+Baghdad Iraq Bern Switzerland
+Baghdad Iraq Cairo Egypt
+Baghdad Iraq Canberra Australia
+Baghdad Iraq Hanoi Vietnam
+Baghdad Iraq Havana Cuba
+Baghdad Iraq Helsinki Finland
+Baghdad Iraq Islamabad Pakistan
+Baghdad Iraq Kabul Afghanistan
+Baghdad Iraq London England
+Baghdad Iraq Madrid Spain
+Baghdad Iraq Moscow Russia
+Baghdad Iraq Oslo Norway
+Baghdad Iraq Ottawa Canada
+Baghdad Iraq Paris France
+Baghdad Iraq Rome Italy
+Baghdad Iraq Stockholm Sweden
+Baghdad Iraq Tehran Iran
+Baghdad Iraq Tokyo Japan
+Baghdad Iraq Athens Greece
+Bangkok Thailand Beijing China
+Bangkok Thailand Berlin Germany
+Bangkok Thailand Bern Switzerland
+Bangkok Thailand Cairo Egypt
+Bangkok Thailand Canberra Australia
+Bangkok Thailand Hanoi Vietnam
+Bangkok Thailand Havana Cuba
+Bangkok Thailand Helsinki Finland
+Bangkok Thailand Islamabad Pakistan
+Bangkok Thailand Kabul Afghanistan
+Bangkok Thailand London England
+Bangkok Thailand Madrid Spain
+Bangkok Thailand Moscow Russia
+Bangkok Thailand Oslo Norway
+Bangkok Thailand Ottawa Canada
+Bangkok Thailand Paris France
+Bangkok Thailand Rome Italy
+Bangkok Thailand Stockholm Sweden
+Bangkok Thailand Tehran Iran
+Bangkok Thailand Tokyo Japan
+Bangkok Thailand Athens Greece
+Bangkok Thailand Baghdad Iraq
+Beijing China Berlin Germany
+Beijing China Bern Switzerland
+Beijing China Cairo Egypt
+Beijing China Canberra Australia
+Beijing China Hanoi Vietnam
+Beijing China Havana Cuba
+Beijing China Helsinki Finland
+Beijing China Islamabad Pakistan
+Beijing China Kabul Afghanistan
+Beijing China London England
+Beijing China Madrid Spain
+Beijing China Moscow Russia
+Beijing China Oslo Norway
+Beijing China Ottawa Canada
+Beijing China Paris France
+Beijing China Rome Italy
+Beijing China Stockholm Sweden
+Beijing China Tehran Iran
+Beijing China Tokyo Japan
+Beijing China Athens Greece
+Beijing China Baghdad Iraq
+Beijing China Bangkok Thailand
+Berlin Germany Bern Switzerland
+Berlin Germany Cairo Egypt
+Berlin Germany Canberra Australia
+Berlin Germany Hanoi Vietnam
+Berlin Germany Havana Cuba
+Berlin Germany Helsinki Finland
+Berlin Germany Islamabad Pakistan
+Berlin Germany Kabul Afghanistan
+Berlin Germany London England
+Berlin Germany Madrid Spain
+Berlin Germany Moscow Russia
+Berlin Germany Oslo Norway
+Berlin Germany Ottawa Canada
+Berlin Germany Paris France
+Berlin Germany Rome Italy
+Berlin Germany Stockholm Sweden
+Berlin Germany Tehran Iran
+Berlin Germany Tokyo Japan
+Berlin Germany Athens Greece
+Berlin Germany Baghdad Iraq
+Berlin Germany Bangkok Thailand
+Berlin Germany Beijing China
+Bern Switzerland Cairo Egypt
+Bern Switzerland Canberra Australia
+Bern Switzerland Hanoi Vietnam
+Bern Switzerland Havana Cuba
+Bern Switzerland Helsinki Finland
+Bern Switzerland Islamabad Pakistan
+Bern Switzerland Kabul Afghanistan
+Bern Switzerland London England
+Bern Switzerland Madrid Spain
+Bern Switzerland Moscow Russia
+Bern Switzerland Oslo Norway
+Bern Switzerland Ottawa Canada
+Bern Switzerland Paris France
+Bern Switzerland Rome Italy
+Bern Switzerland Stockholm Sweden
+Bern Switzerland Tehran Iran
+Bern Switzerland Tokyo Japan
+Bern Switzerland Athens Greece
+Bern Switzerland Baghdad Iraq
+Bern Switzerland Bangkok Thailand
+Bern Switzerland Beijing China
+Bern Switzerland Berlin Germany
+Cairo Egypt Canberra Australia
+Cairo Egypt Hanoi Vietnam
+Cairo Egypt Havana Cuba
+Cairo Egypt Helsinki Finland
+Cairo Egypt Islamabad Pakistan
+Cairo Egypt Kabul Afghanistan
+Cairo Egypt London England
+Cairo Egypt Madrid Spain
+Cairo Egypt Moscow Russia
+Cairo Egypt Oslo Norway
+Cairo Egypt Ottawa Canada
+Cairo Egypt Paris France
+Cairo Egypt Rome Italy
+Cairo Egypt Stockholm Sweden
+Cairo Egypt Tehran Iran
+Cairo Egypt Tokyo Japan
+Cairo Egypt Athens Greece
+Cairo Egypt Baghdad Iraq
+Cairo Egypt Bangkok Thailand
+Cairo Egypt Beijing China
+Cairo Egypt Berlin Germany
+Cairo Egypt Bern Switzerland
+Canberra Australia Hanoi Vietnam
+Canberra Australia Havana Cuba
+Canberra Australia Helsinki Finland
+Canberra Australia Islamabad Pakistan
+Canberra Australia Kabul Afghanistan
+Canberra Australia London England
+Canberra Australia Madrid Spain
+Canberra Australia Moscow Russia
+Canberra Australia Oslo Norway
+Canberra Australia Ottawa Canada
+Canberra Australia Paris France
+Canberra Australia Rome Italy
+Canberra Australia Stockholm Sweden
+Canberra Australia Tehran Iran
+Canberra Australia Tokyo Japan
+Canberra Australia Athens Greece
+Canberra Australia Baghdad Iraq
+Canberra Australia Bangkok Thailand
+Canberra Australia Beijing China
+Canberra Australia Berlin Germany
+Canberra Australia Bern Switzerland
+Canberra Australia Cairo Egypt
+Hanoi Vietnam Havana Cuba
+Hanoi Vietnam Helsinki Finland
+Hanoi Vietnam Islamabad Pakistan
+Hanoi Vietnam Kabul Afghanistan
+Hanoi Vietnam London England
+Hanoi Vietnam Madrid Spain
+Hanoi Vietnam Moscow Russia
+Hanoi Vietnam Oslo Norway
+Hanoi Vietnam Ottawa Canada
+Hanoi Vietnam Paris France
+Hanoi Vietnam Rome Italy
+Hanoi Vietnam Stockholm Sweden
+Hanoi Vietnam Tehran Iran
+Hanoi Vietnam Tokyo Japan
+Hanoi Vietnam Athens Greece
+Hanoi Vietnam Baghdad Iraq
+Hanoi Vietnam Bangkok Thailand
+Hanoi Vietnam Beijing China
+Hanoi Vietnam Berlin Germany
+Hanoi Vietnam Bern Switzerland
+Hanoi Vietnam Cairo Egypt
+Hanoi Vietnam Canberra Australia
+Havana Cuba Helsinki Finland
+Havana Cuba Islamabad Pakistan
diff --git a/models/recall/word2vec/data/train/convert_sample.txt b/models/recall/word2vec/data/train/convert_sample.txt
new file mode 100644
index 00000000..d6c7b789
--- /dev/null
+++ b/models/recall/word2vec/data/train/convert_sample.txt
@@ -0,0 +1,195 @@
+45 8 71 53 83 58 71 28 46 3
+59 68 5 82 0 81
+61
+52
+80 2 4
+18
+0 45 10 10 0 8 45 5 0 10
+16 16 14 10
+71 73 23 32 16 0 49 53
+67 6 26 5
+18 37 30 65
+16 75 30
+1 42 25
+54 43 0 6 0 10 0 66
+20 13 7 49 5 46 37 0
+32 1 40 55 74
+16 14 3
+76 29 14 3 44 13 42 44 34 3
+4 80
+32 37 0 3 0 22 6 8 3 62
+13 75 9 6 65 79 8 24 0 24 6 73
+81 0
+79 7 40 14 5 6
+58 56 38
+23 14 6 2 51
+12 24 6
+18 37 55
+0 14 43 50
+52 53
+22 19 11 6 6 41
+20 68 7 66 59 66 31 48
+31 2 70 15 24 24 44 72 68 14 27 6
+2
+28 10
+35 51 6 0 64 17 4 21 13 0 11
+9 33 43
+26 4
+4
+69 29 4 8
+0 76 46 0 51 30 34 20 79 22
+1 49 9 25 0 25 78
+10
+81 57 81 72
+8 34 31 29 37
+38 13
+9 5 6 39 54 43
+81 70 18 2 53
+55 7 44 21 30 0 60
+19 23 3 0 39 82
+28 56 27 4 38 55 2
+41 17 0 43 6 21 41 27
+70 29 59
+5 36 36 31 26 17 8 39 78
+28 64 11 8 21 41 11 16 7 16 20
+8
+13 40 61 68
+9
+57 40 72 7 71 29 2 22 29 38 1 30
+0 3
+39 0 4 5 39 21 41 5 54 45
+22 7 1 1 0 0
+46 0 0 20 40 29 3
+11 0 78 4 15 82 51
+0 2 33
+0 21 41 19 29 2 59 36
+27 3 14 0
+32 63 84 63 84 3 63 84 0 63 84
+36 13
+13 15
+36 57 35 34 54 0
+13 22 31 5
+3 78 2 2
+27 11 57 20 20 11
+67
+28 70 44 58 0 28 17 7 17 29
+53 11 62 17 6 17 12 30
+32 81
+80 0 35 22 19 6 35 51
+55 33 76 0 9 0
+0
+56
+52
+42 62 0
+50 1 34 38 0 58 21
+54 62 0 10
+13 1 42 25 4 3 3 0 0
+25 26 9
+28 18 39
+4 49 77 32 49 33
+13 0
+6 11 56 52 10
+15 12 74 1 8 45 44 8 0 14 12 6 12 9 8 45 0 44 76 4 3 12 11 0
+35 48 23 1 0
+8 5 54 15 5 1
+20 38 0 48 7 30 0 17 29 32 76 14
+8 46 37
+64 53 0 0 24 0 13 6
+0 52 0 1 3 0
+55 1 43 24 34 24 71 28 42 1
+83 15 57 46 24 3 40 14
+61 47 23 1
+31 0 26 24 25 36 16 27 12 11 33 25 43
+20 34 57 52 2 70 56 7 57 52
+44
+62 26 69 8
+74 1 6 51 33 74
+49 0 22 1 0 17 32 14 21 22 3 45 26 10
+5 78
+64 35 18 75 5 80 0 24 53 26 0
+48 83
+79 61
+60 1 23 9 10
+50 3 1 3 24 0
+1 47 27 30 67 4 83 61
+32 15 69 36 19 6 7 42
+34 47 33 68
+63 16 38 11 67
+1 50 4
+65 27 78
+27 48 39 16 14 76
+13 0 42 34 36 20 19 33
+7 19 31 37 25
+5 42 64
+4 42 23 8
+77 50 4 31 5
+9 14 5 0 3 27
+19 27 1 40 2 1 77 40 29 14 2 1 25 69
+33 73 7 18 25 35 29 14 58 0 0 35 70
+23 6 4
+53 3 0 46 4 74 58 42 1
+35 27 77 8 4 77 0
+0 17 48 0 0 6 22
+19 0 2 43 59 0 61
+20 71 79 20 14 41 1
+37 73 65
+9 3
+0 5 10
+0 6 42 8 47 74
+23 9 18 62 23
+47
+39 50 0 50 26
+69 26 66 38 14 72 15 1
+6 21 33 65 24 9 2
+60
+16 25 22 16 15 0
+18 37 0 28 50
+40 75 5
+36 66 11 38 0 3 36
+5 26 59 66 0
+45 10 6 7 31 21 41 27 4
+72 30 10 4 0 83 2 30 47 67 33
+17 6 64 29 0
+0 30 38 12 5 18 4 0
+60 83 3 55 3
+0 4 0 33 43 80 8 75
+5 77 0 22 30 21 41 27
+36 19 3 0 82 49
+6 32 17
+0 10 0 62 8 82
+54 11 38 4 2 19 7 35 18 39 0 16 14
+37 0 47 75
+61 0 58 1 48 33 32 10
+47 10 73 47
+17 34 2
+7 56 28 0 2
+39 23 15 15 6 13
+9 15 0 13 45 2 14 15
+0 11 0 0 72 11 13 5 26 3 0
+0 19 38 12 1 3
+67 12
+36 26 0 5
+56 60 18 37 1 44
+11 13 11 40 12
+19 56
+57 0 22 40 35 0 51 6 28 28 13
+73
+34 22 65 64 28 52 44
+13 1 25 63 84 6 7 12 41 63 84
+69 46 4 0
+17 3
+0 3 13 55 3 26 46
+2 2 21 7 67
+45 34 0 14 21 60 2
+80 11 18 34 29 60 4 14
+48
+27
+21 41 0 66 34
+54 43 0 0
+79 68 13 23 5 51 8
+0 49 31 23 4
+59 20 48 35 16 5 8
+22 0 8 26 49 39 10
+37 4 24 0 5 6 65 68 11 0
+11 0 2 25 7
+3 82 18 0
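
Note on the data above: convert_sample.txt is the training text with every word
already replaced by its id from word_id_dict.txt (id 0 is <UNK>). A sketch of that
preprocessing step, assuming the two-column dict format shown above (the helper
name and path are illustrative):

    import io

    def load_word_ids(path):
        # Each line of word_id_dict.txt is "<word> <id>".
        word_to_id = {}
        with io.open(path, 'r', encoding='utf-8') as f:
            for line in f:
                word, word_id = line.rstrip('\n').split(' ')
                word_to_id[word] = int(word_id)
        return word_to_id

    word_to_id = load_word_ids('data/dict/word_id_dict.txt')
    line = u"the city was new"
    # Unknown words fall back to id 0 (<UNK>).
    print(u' '.join(str(word_to_id.get(w, 0)) for w in line.split()))
    # -> "1 78 19 39"
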
diff --git a/models/recall/word2vec/model.py b/models/recall/word2vec/model.py
new file mode 100644
index 00000000..9c35cb81
--- /dev/null
+++ b/models/recall/word2vec/model.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import numpy as np
+import paddle.fluid as fluid
+
+from fleetrec.core.utils import envs
+from fleetrec.core.model import Model as ModelBase
+
+
+class Model(ModelBase):
+    def __init__(self, config):
+        ModelBase.__init__(self, config)
+
+    def input(self):
+        neg_num = int(envs.get_global_env("hyper_parameters.neg_num", None, self._namespace))
+        self.input_word = fluid.data(name="input_word", shape=[None, 1], dtype='int64')
+        self.true_word = fluid.data(name='true_label', shape=[None, 1], dtype='int64')
+        self._data_var.append(self.input_word)
+        self._data_var.append(self.true_word)
+        with_shuffle_batch = bool(int(envs.get_global_env("hyper_parameters.with_shuffle_batch", None, self._namespace)))
+        if not with_shuffle_batch:
+            self.neg_word = fluid.data(name="neg_label", shape=[None, neg_num], dtype='int64')
+            self._data_var.append(self.neg_word)
+
+        if self._platform != "LINUX":
+            self._data_loader = fluid.io.DataLoader.from_generator(
+                feed_list=self._data_var, capacity=64,
+                use_double_buffer=False, iterable=False)
+
+    def net(self):
+        is_distributed = True if envs.get_trainer() == "CtrTrainer" else False
+        neg_num = int(envs.get_global_env("hyper_parameters.neg_num", None, self._namespace))
+        sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
+        sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
+        with_shuffle_batch = bool(int(envs.get_global_env("hyper_parameters.with_shuffle_batch", None, self._namespace)))
+
+        def embedding_layer(input, table_name, emb_dim, initializer_instance=None, squeeze=False):
+            emb = fluid.embedding(
+                input=input,
+                is_sparse=True,
+                is_distributed=is_distributed,
+                size=[sparse_feature_number, emb_dim],
+                param_attr=fluid.ParamAttr(
+                    name=table_name,
+                    initializer=initializer_instance))
+            if squeeze:
+                return fluid.layers.squeeze(input=emb, axes=[1])
+            else:
+                return emb
+
+        init_width = 0.5 / sparse_feature_dim
+        emb_initializer = fluid.initializer.Uniform(-init_width, init_width)
+        emb_w_initializer = fluid.initializer.Constant(value=0.0)
+
+        input_emb = embedding_layer(self.input_word, "emb", sparse_feature_dim, emb_initializer, True)
+        true_emb_w = embedding_layer(self.true_word, "emb_w", sparse_feature_dim, emb_w_initializer, True)
+        true_emb_b = embedding_layer(self.true_word, "emb_b", 1, emb_w_initializer, True)
+
+        if with_shuffle_batch:
+            neg_emb_w_list = []
+            for i in range(neg_num):
+                neg_emb_w_list.append(
+                    fluid.contrib.layers.shuffle_batch(true_emb_w))  # shuffle true_word
+            neg_emb_w_concat = fluid.layers.concat(neg_emb_w_list, axis=0)
+            neg_emb_w = fluid.layers.reshape(
+                neg_emb_w_concat, shape=[-1, neg_num, sparse_feature_dim])
+
+            neg_emb_b_list = []
+            for i in range(neg_num):
+                neg_emb_b_list.append(
+                    fluid.contrib.layers.shuffle_batch(true_emb_b))  # shuffle true_word
+            neg_emb_b = fluid.layers.concat(neg_emb_b_list, axis=0)
+            neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
+        else:
+            neg_emb_w = embedding_layer(self.neg_word, "emb_w", sparse_feature_dim, emb_w_initializer)
+            neg_emb_b = embedding_layer(self.neg_word, "emb_b", 1, emb_w_initializer)
+            neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
+
+        true_logits = fluid.layers.elementwise_add(
+            fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(input_emb, true_emb_w),
+                dim=1,
+                keep_dim=True),
+            true_emb_b)
+
+        input_emb_re = fluid.layers.reshape(
+            input_emb, shape=[-1, 1, sparse_feature_dim])
+        neg_matmul = fluid.layers.matmul(input_emb_re, neg_emb_w, transpose_y=True)
+        neg_logits = fluid.layers.elementwise_add(
+            fluid.layers.reshape(neg_matmul, shape=[-1, neg_num]),
+            neg_emb_b_vec)
+
+        label_ones = fluid.layers.fill_constant_batch_size_like(
+            true_logits, shape=[-1, 1], value=1.0, dtype='float32')
+        label_zeros = fluid.layers.fill_constant_batch_size_like(
+            true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')
+
+        true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
+                                                                   label_ones)
+        neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
+                                                                  label_zeros)
+        cost = fluid.layers.elementwise_add(
+            fluid.layers.reduce_sum(true_xent, dim=1),
+            fluid.layers.reduce_sum(neg_xent, dim=1))
+        self.avg_cost = fluid.layers.reduce_mean(cost)
+        global_right_cnt = fluid.layers.create_global_var(
+            name="global_right_cnt", persistable=True, dtype='float32', shape=[1], value=0)
+        global_total_cnt = fluid.layers.create_global_var(
+            name="global_total_cnt", persistable=True, dtype='float32', shape=[1], value=0)
+        global_right_cnt.stop_gradient = True
+        global_total_cnt.stop_gradient = True
+
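Note on net() above: this is skip-gram with negative sampling. Each (input,
context) pair gets one positive logit plus neg_num negative logits, trained with
sigmoid cross-entropy against labels 1 and 0. A NumPy sketch of the same loss on
toy tensors (all shapes and values made up for illustration):

    import numpy as np

    rng = np.random.default_rng(0)           # batch 2, dim 4, 3 negatives
    input_emb = rng.normal(size=(2, 4))
    true_w, true_b = rng.normal(size=(2, 4)), np.zeros((2, 1))
    neg_w, neg_b = rng.normal(size=(2, 3, 4)), np.zeros((2, 3))

    true_logits = (input_emb * true_w).sum(axis=1, keepdims=True) + true_b
    neg_logits = np.einsum('bd,bnd->bn', input_emb, neg_w) + neg_b

    def xent(x, z):
        # Numerically stable sigmoid cross-entropy with logits.
        return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

    cost = xent(true_logits, 1.0).sum(axis=1) + xent(neg_logits, 0.0).sum(axis=1)
    print(cost.mean())  # the scalar the trainer minimizes
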
+    def avg_loss(self):
+        self._cost = self.avg_cost
+
+    def metrics(self):
+        self._metrics["LOSS"] = self.avg_cost
+
+    def train_net(self):
+        self.input()
+        self.net()
+        self.avg_loss()
+        self.metrics()
+
+    def optimizer(self):
+        learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
+        decay_steps = envs.get_global_env("hyper_parameters.decay_steps", None, self._namespace)
+        decay_rate = envs.get_global_env("hyper_parameters.decay_rate", None, self._namespace)
+        optimizer = fluid.optimizer.SGD(
+            learning_rate=fluid.layers.exponential_decay(
+                learning_rate=learning_rate,
+                decay_steps=decay_steps,
+                decay_rate=decay_rate,
+                staircase=True))
+        return optimizer
+
+    def analogy_input(self):
+        sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
+        self.analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64')
+        self.analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64')
+        self.analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64')
+        self.analogy_d = fluid.data(name="analogy_d", shape=[None], dtype='int64')
+        self._infer_data_var = [
+            self.analogy_a, self.analogy_b, self.analogy_c, self.analogy_d
+        ]
+
+        self._infer_data_loader = fluid.io.DataLoader.from_generator(
+            feed_list=self._infer_data_var, capacity=64,
+            use_double_buffer=False, iterable=False)
+
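Note on the analogy evaluation that follows: infer_net scores
target = emb(b) - emb(a) + emb(c) against the L2-normalized embedding of every
word and counts a question as right when d lands among the top 4. A NumPy sketch
with a made-up 5-word table (illustrative only):

    import numpy as np

    rng = np.random.default_rng(1)
    emb = rng.normal(size=(5, 4))      # stands in for the "emb" table
    a, b, c, d = 1, 2, 3, 4            # word ids of one analogy question

    target = emb[b] - emb[a] + emb[c]
    emb_l2 = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    dist = emb_l2 @ target             # one score per vocabulary word
    pred_top4 = np.argsort(-dist)[:4]  # ids of the 4 nearest words

    print(float(d in pred_top4))       # 1.0 on a hit, 0.0 otherwise
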
+    def infer_net(self):
+        sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
+        sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
+
+        def embedding_layer(input, table_name, initializer_instance=None):
+            emb = fluid.embedding(
+                input=input,
+                size=[sparse_feature_number, sparse_feature_dim],
+                param_attr=table_name)
+            return emb
+
+        self.analogy_input()
+        all_label = np.arange(sparse_feature_number).reshape(
+            sparse_feature_number).astype('int32')
+        self.all_label = fluid.layers.cast(
+            x=fluid.layers.assign(all_label), dtype='int64')
+        emb_all_label = embedding_layer(self.all_label, "emb")
+        emb_a = embedding_layer(self.analogy_a, "emb")
+        emb_b = embedding_layer(self.analogy_b, "emb")
+        emb_c = embedding_layer(self.analogy_c, "emb")
+
+        target = fluid.layers.elementwise_add(
+            fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
+
+        emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
+        dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
+        values, pred_idx = fluid.layers.topk(input=dist, k=4)
+        label = fluid.layers.expand(
+            fluid.layers.unsqueeze(self.analogy_d, axes=[1]),
+            expand_times=[1, 4])
+        label_ones = fluid.layers.fill_constant_batch_size_like(
+            label, shape=[-1, 1], value=1.0, dtype='float32')
+        right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast(
+            fluid.layers.equal(pred_idx, label), dtype='float32'))
+        total_cnt = fluid.layers.reduce_sum(label_ones)
+
+        global_right_cnt = fluid.layers.create_global_var(
+            name="global_right_cnt", persistable=True, dtype='float32', shape=[1], value=0)
+        global_total_cnt = fluid.layers.create_global_var(
+            name="global_total_cnt", persistable=True, dtype='float32', shape=[1], value=0)
+        global_right_cnt.stop_gradient = True
+        global_total_cnt.stop_gradient = True
+
+        tmp1 = fluid.layers.elementwise_add(right_cnt, global_right_cnt)
+        fluid.layers.assign(tmp1, global_right_cnt)
+        tmp2 = fluid.layers.elementwise_add(total_cnt, global_total_cnt)
+        fluid.layers.assign(tmp2, global_total_cnt)
+
+        acc = fluid.layers.elementwise_div(
+            global_right_cnt, global_total_cnt, name="total_acc")
+        self._infer_results['acc'] = acc
--
GitLab