Commit e858dff8 authored by M malin10

bug fix for w2v

Parent 8ef34ea6
......@@ -68,6 +68,8 @@ class DataLoader(DatasetBase):
            reader_ins = SlotReader(context["config_yaml"])
        if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_list_generator(reader)
        elif hasattr(reader_ins, 'batch_tensor_creator'):
            dataloader.set_batch_generator(reader)
        else:
            dataloader.set_sample_generator(reader, batch_size)
        return dataloader
......
......@@ -67,6 +67,10 @@ def dataloader_by_name(readerclass,
    if hasattr(reader, 'generate_batch_from_trainfiles'):
        return gen_batch_reader()
    if hasattr(reader, "batch_tensor_creator"):
        return reader.batch_tensor_creator(gen_reader)
    return gen_reader
......
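For context: the two hunks above route readers that expose a `batch_tensor_creator` to the DataLoader's batch-feeding entry point. A minimal sketch of the two feeding styles of `fluid.io.DataLoader` involved here (assuming the paddle.fluid 1.x API; the generator names are illustrative, not from PaddleRec):

```python
import numpy as np
import paddle.fluid as fluid

x = fluid.data(name="x", shape=[None, 1], dtype="int64")

def sample_gen():
    # One sample per yield: a list with one numpy array per feed var.
    for i in range(8):
        yield [np.array([i], dtype="int64")]

def batch_gen():
    # One whole batch per yield, already shaped (batch_size, ...) --
    # the style produced by a batch_tensor_creator-like reader.
    for _ in range(2):
        yield [np.arange(4, dtype="int64").reshape(4, 1)]

# Per-sample feeding: the DataLoader does the batching itself.
loader_a = fluid.io.DataLoader.from_generator(feed_list=[x], capacity=4)
loader_a.set_sample_generator(sample_gen, batch_size=4, places=fluid.CPUPlace())

# Pre-batched feeding, as selected by the hasattr checks above.
loader_b = fluid.io.DataLoader.from_generator(feed_list=[x], capacity=4)
loader_b.set_batch_generator(batch_gen, places=fluid.CPUPlace())
```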
......@@ -19,6 +19,8 @@
├── data_prepare.sh # one-click data preparation script
├── w2v_reader.py # training data reader
├── w2v_evaluate_reader.py # inference data reader
├── infer.py # custom inference script
├── utils.py # reader and other utilities used by the custom inference script
```
Note: before reading this example, we recommend first familiarizing yourself with the following:
......@@ -154,9 +156,12 @@ runner:
  phases: [phase1]
```
### Single-machine Inference
We evaluate the trained word2vec model with the Word Analogy task. Given four words A, B, C, D, assume there is a relation such that relation(A, B) = relation(C, D); the model then predicts D from A, B, C via emb(D) = emb(B) - emb(A) + emb(C).
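A toy numpy sketch of that formula (illustrative only; the project's actual evaluation is the `infer.py` script described below):

```python
import numpy as np

def predict_analogy(emb, a, b, c):
    # Score every word by cosine similarity against emb(B) - emb(A) + emb(C).
    norm_emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    target = emb[b] - emb[a] + emb[c]
    return int(np.argmax(norm_emb.dot(target)))

rng = np.random.RandomState(0)
emb = rng.rand(1000, 300).astype("float32")  # stand-in embedding table
print(predict_analogy(emb, a=1, b=2, c=3))   # id of the predicted word D
```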
CPU environment
PaddleRec inference configuration:
Set parameters such as epochs and device in config.yaml.
```
......@@ -168,6 +173,10 @@ CPU环境
  print_interval: 1
  phases: [phase2]
```
To reproduce the results reported in the paper, we provide a custom inference script. During custom inference, we skip predictions that merely repeat the input words A, B, C, and then compute prediction accuracy. Run:
```
python infer.py --test_dir ./data/test --dict_path ./data/dict/word_id_dict.txt --batch_size 20000 --model_dir ./increment_w2v/ --start_index 0 --last_index 5 --emb_size 300
```
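The skip rule amounts to the following (a standalone sketch mirroring the inner loop of `infer.py`; `top_k` is the id list from the topk op, best first):

```python
def analogy_hit(top_k, input_words, label):
    for idx in top_k:
        if idx in input_words:  # prediction merely echoes A, B or C: skip it
            continue
        return idx == label     # first non-input prediction decides the hit
    return False

# The two best predictions echo the inputs; the third one is the label.
print(analogy_hit([5, 7, 9, 11], input_words=[5, 7, 3], label=9))  # True
```

This is presumably why `infer.py` asks topk for k=4: after skipping at most three input words, at least one candidate remains.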
### Run
```
......@@ -212,13 +221,12 @@ Infer phase2 of epoch 3 done, use time: 4.43099021912, global metrics: acc=[1.]
- batch_size: set the batch_size of the dataset_train dataset in config.yaml to 100.
- epochs: set the epochs of the runner in config.yaml to 5.
Training on CPU for 5 epochs, test Recall@20: 0.540
To run after the modifications, set 'workspace' in config.yaml to the directory containing config.yaml, then execute:
```
python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: pass the absolute path of the local config directly
```
Training on CPU for 5 epochs, the custom evaluation (skipping input words) reaches an accuracy of 0.540.
## Advanced Usage
## FAQ
......@@ -22,7 +22,7 @@ dataset:
  word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
  data_converter: "{workspace}/w2v_reader.py"
- name: dataset_infer # name
  batch_size: 50
  batch_size: 2000
  type: DataLoader # or QueueDataset
  data_path: "{workspace}/data/test"
  word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
......@@ -59,7 +59,7 @@ runner:
  save_inference_feed_varnames: [] # feed vars of save inference
  save_inference_fetch_varnames: [] # fetch vars of save inference
  init_model_path: "" # load model path
  print_interval: 1
  print_interval: 1000
  phases: [phase1]
- name: single_cpu_infer
  class: infer
......
......@@ -25,7 +25,7 @@ mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tok
python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt
python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001
mv raw_data/word_count_dict.txt data/dict/
mv raw_data/word_id_dict.txt data/dict/
mv raw_data/word_count_dict.txt_word_to_id_ data/dict/word_id_dict.txt
rm -rf data/train/*
rm -rf data/test/*
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import time
import math
import numpy as np
import six
import paddle.fluid as fluid
import paddle
import utils

if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf-8')


def parse_args():
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of the word-to-id dict")
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step', type=int, default=500000, help='print step')
    parser.add_argument(
        '--start_index', type=int, default=0, help='start index')
    parser.add_argument(
        '--last_index', type=int, default=100, help='last index')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether to use cuda')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch size')
    parser.add_argument(
        '--emb_size', type=int, default=64, help='embedding size')
    args = parser.parse_args()
    return args
def infer_network(vocab_size, emb_size):
    analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64')
    analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64')
    analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64')
    all_label = fluid.data(name="all_label", shape=[vocab_size], dtype='int64')
    emb_all_label = fluid.embedding(
        input=all_label, size=[vocab_size, emb_size], param_attr="emb")

    emb_a = fluid.embedding(
        input=analogy_a, size=[vocab_size, emb_size], param_attr="emb")
    emb_b = fluid.embedding(
        input=analogy_b, size=[vocab_size, emb_size], param_attr="emb")
    emb_c = fluid.embedding(
        input=analogy_c, size=[vocab_size, emb_size], param_attr="emb")
    # target = emb(B) - emb(A) + emb(C)
    target = fluid.layers.elementwise_add(
        fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
    # Cosine similarity of the target against all L2-normalized embeddings.
    emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
    dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
    values, pred_idx = fluid.layers.topk(input=dist, k=4)
    return values, pred_idx
def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """ inference function """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    batch_size = args.batch_size
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = infer_network(vocab_size, emb_size)
            for epoch in range(start_index, last_index + 1):
                copy_program = main_program.clone()
                model_path = model_dir + "/" + str(epoch)
                fluid.io.load_persistables(
                    exe, model_path, main_program=copy_program)
                accum_num = 0
                accum_num_sum = 0.0
                t0 = time.time()
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    b_size = len([dat[0] for dat in data])
                    wa = np.array([dat[0] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wb = np.array([dat[1] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wc = np.array([dat[2] for dat in data]).astype(
                        "int64").reshape(b_size)
                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    para = exe.run(
                        copy_program,
                        feed={
                            "analogy_a": wa,
                            "analogy_b": wb,
                            "analogy_c": wc,
                            "all_label": np.arange(vocab_size)
                            .reshape(vocab_size).astype("int64"),
                        },
                        fetch_list=[pred.name, values],
                        return_numpy=False)
                    pre = np.array(para[0])
                    val = np.array(para[1])
                    for ii in range(len(label)):
                        top4 = pre[ii]
                        accum_num_sum += 1
                        for idx in top4:
                            # Skip predictions that echo an input word.
                            if int(idx) in input_word[ii]:
                                continue
                            if int(idx) == int(label[ii][0]):
                                accum_num += 1
                            break
                    if step_id % 1 == 0:
                        print("step:%d %d " % (step_id, accum_num))
                print("epoch:%d \t acc:%.3f " %
                      (epoch, 1.0 * accum_num / accum_num_sum))
if __name__ == "__main__":
    args = parse_args()
    start_index = args.start_index
    last_index = args.last_index
    test_dir = args.test_dir
    model_dir = args.model_dir
    batch_size = args.batch_size
    dict_path = args.dict_path
    use_cuda = True if args.use_cuda else False
    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)
    infer_epoch(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)
......@@ -209,10 +209,10 @@ class Model(ModelBase):
        emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
        dist = fluid.layers.matmul(
            x=target, y=emb_all_label_l2, transpose_y=True)
        values, pred_idx = fluid.layers.topk(input=dist, k=4)
        values, pred_idx = fluid.layers.topk(input=dist, k=1)
        label = fluid.layers.expand(
            fluid.layers.unsqueeze(
                inputs[3], axes=[1]), expand_times=[1, 4])
                inputs[3], axes=[1]), expand_times=[1, 1])
        label_ones = fluid.layers.fill_constant_batch_size_like(
            label, shape=[-1, 1], value=1.0, dtype='float32')
        right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
import os
import preprocess
import io
import logging

# Module-level logger used by check_version below.
logger = logging.getLogger(__name__)


def BuildWord_IdMap(dict_path):
    word_to_id = dict()
    id_to_word = dict()
    with io.open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
            id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
    return word_to_id, id_to_word


def prepare_data(file_dir, dict_path, batch_size):
    w2i, i2w = BuildWord_IdMap(dict_path)
    vocab_size = len(i2w)
    reader = fluid.io.batch(test(file_dir, w2i), batch_size)
    return vocab_size, reader, i2w
def check_version(with_shuffle_batch=False):
    """
    Log error and exit when the installed version of paddlepaddle is
    not satisfied.
    """
    err = "PaddlePaddle version 1.6 or higher is required, " \
          "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code."

    try:
        if with_shuffle_batch:
            fluid.require_version('1.7.0')
        else:
            fluid.require_version('1.6.0')
    except Exception as e:
        logger.error(err)
        sys.exit(1)
def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)
def strip_lines(line, vocab):
    return _replace_oov(vocab, native_to_unicode(line))


def _replace_oov(original_vocab, line):
    """Replace out-of-vocab words with "<UNK>".
    This maintains compatibility with published results.
    Args:
      original_vocab: a set of strings (The standard vocabulary for the dataset)
      line: a unicode string - a space-delimited sequence of words.
    Returns:
      a unicode string - a space-delimited sequence of words.
    """
    return u" ".join([
        word if word in original_vocab else u"<UNK>" for word in line.split()
    ])


def reader_creator(file_dir, word_to_id):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for line in f:
                    if ':' in line:
                        # Section headers such as ": capital-common-countries"
                        # carry no analogy tuple; skip them.
                        pass
                    else:
                        line = strip_lines(line.lower(), word_to_id)
                        line = line.split()
                        yield [word_to_id[line[0]]], [word_to_id[line[1]]], [
                            word_to_id[line[2]]
                        ], [word_to_id[line[3]]], [
                            word_to_id[line[0]], word_to_id[line[1]],
                            word_to_id[line[2]]
                        ]

    return reader


def test(test_dir, w2i):
    return reader_creator(test_dir, w2i)
......@@ -76,7 +76,7 @@ class Reader(ReaderBase):
    def generate_sample(self, line):
        def reader():
            if ':' in line:
                pass
                return
            features = self.strip_lines(line.lower(), self.word_to_id)
            features = features.split()
            yield [('analogy_a', [self.word_to_id[features[0]]]),
......
......@@ -15,6 +15,7 @@
import io
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.reader import ReaderBase
from paddlerec.core.utils import envs
......@@ -47,6 +48,10 @@ class Reader(ReaderBase):
        self.with_shuffle_batch = envs.get_global_env(
            "hyper_parameters.with_shuffle_batch")
        self.random_generator = NumpyRandomInt(1, self.window_size + 1)
        self.batch_size = envs.get_global_env(
            "dataset.dataset_train.batch_size")
        self.is_dataloader = envs.get_global_env(
            "dataset.dataset_train.type") == "DataLoader"

        self.cs = None
        if not self.with_shuffle_batch:
......@@ -88,11 +93,46 @@ class Reader(ReaderBase):
                for context_id in context_word_ids:
                    output = [('input_word', [int(target_id)]),
                              ('true_label', [int(context_id)])]
                    if not self.with_shuffle_batch:
                    if self.with_shuffle_batch or self.is_dataloader:
                        yield output
                    else:
                        # Draw negatives by inverse-CDF sampling over the
                        # precomputed cumulative distribution self.cs.
                        neg_array = self.cs.searchsorted(
                            np.random.sample(self.neg_num))
                        output += [('neg_label',
                                    [int(str(i)) for i in neg_array])]
                        yield output
                    yield output

        return reader

    def batch_tensor_creator(self, sample_reader):
        def __reader__():
            result = [[], []]
            for sample in sample_reader():
                for i, fea in enumerate(sample):
                    result[i].append(fea)
                if len(result[0]) == self.batch_size:
                    tensor_result = []
                    for tensor in result:
                        t = fluid.Tensor()
                        dat = np.array(tensor, dtype='int64')
                        if len(dat.shape) > 2:
                            dat = dat.reshape((dat.shape[0], dat.shape[2]))
                        elif len(dat.shape) == 1:
                            dat = dat.reshape((-1, 1))
                        t.set(dat, fluid.CPUPlace())
                        tensor_result.append(t)
                    if self.with_shuffle_batch:
                        yield tensor_result
                    else:
                        # One shared set of negatives, tiled across the batch.
                        tt = fluid.Tensor()
                        neg_array = self.cs.searchsorted(
                            np.random.sample(self.neg_num))
                        neg_array = np.tile(neg_array, self.batch_size)
                        tt.set(
                            neg_array.reshape((self.batch_size, self.neg_num)),
                            fluid.CPUPlace())
                        tensor_result.append(tt)
                        yield tensor_result
                    result = [[], []]

        return __reader__
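A note on the `self.cs.searchsorted(np.random.sample(self.neg_num))` pattern used above: `cs` holds the cumulative sums of the (smoothed) unigram probabilities, so binary-searching uniform samples in it is inverse-CDF sampling. A standalone numpy sketch (the counts and the 3/4 smoothing power are illustrative, following the word2vec paper):

```python
import numpy as np

counts = np.array([100.0, 50.0, 10.0, 5.0, 1.0])  # hypothetical word counts
probs = counts ** 0.75           # word2vec's unigram smoothing
probs /= probs.sum()
cs = np.cumsum(probs)            # monotone CDF over word ids 0..4

# A uniform u in [0, 1) falls into word i's CDF segment with probability
# probs[i]; searchsorted locates that segment by binary search.
neg_ids = cs.searchsorted(np.random.sample(5))
print(neg_ids)                   # ids drawn approximately according to probs
```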