未验证 提交 fc5a4e1a 编写于 作者: J Jiabin Yang 提交者: GitHub

Merge pull request #1448 from jacquesqiao/add-word2vec

Add word2vec
.DS_Store paddle/operators/check_t.save
*.pyc paddle/operators/check_tensor.ls
.*~ paddle/operators/tensor.save
fluid/neural_machine_translation/transformer/deps python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
fluid/neural_machine_translation/transformer/train.data python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
fluid/neural_machine_translation/transformer/train.pkl python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
fluid/neural_machine_translation/transformer/train.sh *.DS_Store
fluid/neural_machine_translation/transformer/train.tok.clean.bpe.32000.en-de *.vs
fluid/neural_machine_translation/transformer/vocab.bpe.32000.refined build/
build_doc/
*.user
.vscode
.idea
.project
.cproject
.pydevproject
.settings/
CMakeSettings.json
Makefile
.test_env/
third_party/
*~
bazel-*
third_party/
build_*
# clion workspace.
cmake-build-*
model_test
# 基于skip-gram的word2vector模型
## 介绍
## 运行环境
需要先安装PaddlePaddle Fluid
## 数据集
数据集使用的是来自1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark)的数据集.
下载数据集:
```bash
cd data && ./download.sh && cd ..
```
## 模型
本例子实现了一个skip-gram模式的word2vector模型。
## 数据准备
对数据进行预处理以生成一个词典。
```bash
python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict
```
## 训练
训练的命令行选项可以通过`python train.py -h`列出。
### 单机训练:
```bash
python train.py \
--train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \
--dict_path data/1-billion_dict \
2>&1 | tee train.log
```
### 分布式训练
本地启动一个2 trainer 2 pserver的分布式训练任务,分布式场景下训练数据会按照trainer的id进行切分,保证trainer之间的训练数据不会重叠,提高训练效率
```bash
sh cluster_train.sh
```
## 预测
在infer.py中我们在`build_test_case`方法中构造了一些test case来评估word embeding的效果:
我们输入test case( 我们目前采用的是analogical-reasoning的任务:找到A - B = C - D的结构,为此我们计算A - B + D,通过cosine距离找最近的C,计算准确率要去除候选中出现A、B、D的候选 )然后计算候选和整个embeding中所有词的余弦相似度,并且取topK(K由参数 --rank_num确定,默认为4)打印出来。
如:
对于:boy - girl + aunt = uncle
0 nearest aunt:0.89
1 nearest uncle:0.70
2 nearest grandmother:0.67
3 nearest father:0.64
您也可以在`build_test_case`方法中模仿给出的例子增加自己的测试
训练中预测:
```bash
python infer.py --infer_during_train 2>&1 | tee infer.log
```
使用某个model进行离线预测:
```bash
python infer.py --infer_once --model_output_dir ./models/[具体的models文件目录] 2>&1 | tee infer.log
```
## 在百度云上运行集群训练
1. 参考文档 [在百度云上启动Fluid分布式训练](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst) 在百度云上部署一个CPU集群。
1. 用preprocess.py处理训练数据生成train.txt。
1. 将train.txt切分成集群机器份,放到每台机器上。
1. 用上面的 `分布式训练` 中的命令行启动分布式训练任务.
# Skip-Gram Word2Vec Model
## Introduction
## Environment
You should install PaddlePaddle Fluid first.
## Dataset
The training data for the 1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark).
Download dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
This model implement a skip-gram model of word2vector.
## Data Preprocessing method
Preprocess the training data to generate a word dict.
```bash
python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict
```
## Train
The command line options for training can be listed by `python train.py -h`.
### Local Train:
```bash
python train.py \
--train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \
--dict_path data/1-billion_dict \
2>&1 | tee train.log
```
### Distributed Train
Run a 2 pserver 2 trainer distribute training on a single machine.
In distributed training setting, training data is splited by trainer_id, so that training data
do not overlap among trainers
```bash
sh cluster_train.sh
```
## Infer
In infer.py we construct some test cases in the `build_test_case` method to evaluate the effect of word embeding:
We enter the test case (we are currently using the analogical-reasoning task: find the structure of A - B = C - D, for which we calculate A - B + D, find the nearest C by cosine distance, the calculation accuracy is removed Candidates for A, B, and D appear in the candidate) Then calculate the cosine similarity of the candidate and all words in the entire embeding, and print out the topK (K is determined by the parameter --rank_num, the default is 4).
Such as:
For: boy - girl + aunt = uncle
0 nearest aunt: 0.89
1 nearest uncle: 0.70
2 nearest grandmother: 0.67
3 nearest father:0.64
You can also add your own tests by mimicking the examples given in the `build_test_case` method.
Forecast in training:
```bash
Python infer.py --infer_during_train 2>&1 | tee infer.log
```
Use a model for offline prediction:
```bash
Python infer.py --infer_once --model_output_dir ./models/[specific models file directory] 2>&1 | tee infer.log
```
## Train on Baidu Cloud
1. Please prepare some CPU machines on Baidu Cloud following the steps in [train_on_baidu_cloud](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst)
1. Prepare dataset using preprocess.py.
1. Split the train.txt to trainer_num parts and put them on the machines.
1. Run training with the cluster train using the command in `Distributed Train` above.
#!/bin/bash
echo "WARNING: This script only for run PaddlePaddle Fluid on one node..."
echo ""
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
export PADDLE_PSERVER_PORTS=36001,36002
export PADDLE_PSERVER_PORT_ARRAY=(36001 36002)
export PADDLE_PSERVERS=2
export PADDLE_IP=127.0.0.1
export PADDLE_TRAINERS=2
export CPU_NUM=2
export NUM_THREADS=2
export PADDLE_SYNC_MODE=TRUE
export PADDLE_IS_LOCAL=0
export FLAGS_rpc_deadline=3000000
export GLOG_logtostderr=1
export TRAIN_DATA=data/enwik8
export DICT_PATH=data/enwik8_dict
export IS_SPARSE="--is_sparse"
echo "Start PSERVER ..."
for((i=0;i<$PADDLE_PSERVERS;i++))
do
cur_port=${PADDLE_PSERVER_PORT_ARRAY[$i]}
echo "PADDLE WILL START PSERVER "$cur_port
GLOG_v=0 PADDLE_TRAINING_ROLE=PSERVER CUR_PORT=$cur_port PADDLE_TRAINER_ID=$i python -u train.py $IS_SPARSE &> pserver.$i.log &
done
echo "Start TRAINER ..."
for((i=0;i<$PADDLE_TRAINERS;i++))
do
echo "PADDLE WILL START Trainer "$i
GLOG_v=0 PADDLE_TRAINER_ID=$i PADDLE_TRAINING_ROLE=TRAINER python -u train.py $IS_SPARSE --train_data_path $TRAIN_DATA --dict_path $DICT_PATH &> trainer.$i.log &
done
\ No newline at end of file
#!/bin/bash
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar -zxvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
import paddle
import time
import os
import paddle.fluid as fluid
import numpy as np
from Queue import PriorityQueue
import logging
import argparse
from sklearn.metrics.pairwise import cosine_similarity
word_to_id = dict()
id_to_word = dict()
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description="PaddlePaddle Word2vec infer example")
parser.add_argument(
'--dict_path',
type=str,
default='./data/1-billion_dict',
help="The path of training dataset")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help="The path for model to store (with infer_once please set specify dir to models) (default: models)"
)
parser.add_argument(
'--rank_num',
type=int,
default=4,
help="find rank_num-nearest result for test (default: 4)")
parser.add_argument(
'--infer_once',
action='store_true',
required=False,
default=False,
help='if using infer_once, (default: False)')
parser.add_argument(
'--infer_during_train',
action='store_true',
required=False,
default=True,
help='if using infer_during_train, (default: True)')
return parser.parse_args()
def BuildWord_IdMap(dict_path):
with open(dict_path + "_word_to_id_", 'r') as f:
for line in f:
word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
def inference_prog():
fluid.layers.create_parameter(
shape=[1, 1], dtype='float32', name="embeding")
def build_test_case(emb):
emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[
'aunt']]
desc1 = "boy - girl + aunt = uncle"
emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[
word_to_id['sisters']]
desc2 = "brother - sister + sisters = brothers"
emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[
'woman']]
desc3 = "king - queen + woman = man"
emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[
word_to_id['slowly']]
desc4 = "reluctant - reluctantly + slowly = slow"
emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[
'deeper']]
desc5 = "old - older + deeper = deep"
return [[emb1, desc1], [emb2, desc2], [emb3, desc3], [emb4, desc4],
[emb5, desc5]]
def inference_test(scope, model_dir, args):
BuildWord_IdMap(args.dict_path)
logger.info("model_dir is: {}".format(model_dir + "/"))
emb = np.array(scope.find_var("embeding").get_tensor())
test_cases = build_test_case(emb)
logger.info("inference result: ====================")
for case in test_cases:
pq = topK(args.rank_num, emb, case[0])
logger.info("Test result for {}".format(case[1]))
pq_tmps = list()
for i in range(args.rank_num):
pq_tmps.append(pq.get())
for i in range(len(pq_tmps)):
logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
pq_tmps[len(pq_tmps) - 1 - i].id], pq_tmps[len(pq_tmps) - 1 - i]
.priority))
del pq_tmps[:]
class PQ_Entry(object):
def __init__(self, cos_similarity, id):
self.priority = cos_similarity
self.id = id
def __cmp__(self, other):
return cmp(self.priority, other.priority)
def topK(k, emb, test_emb):
pq = PriorityQueue(k + 1)
if len(emb) <= k:
for i in range(len(emb)):
x = cosine_similarity([emb[i]], [test_emb])
pq.put(PQ_Entry(x, i))
return pq
for i in range(len(emb)):
x = cosine_similarity([emb[i]], [test_emb])
pq_e = PQ_Entry(x, i)
if pq.full():
pq.get()
pq.put(pq_e)
pq.get()
return pq
def infer_during_train(args):
model_file_list = list()
exe = fluid.Executor(fluid.CPUPlace())
Scope = fluid.Scope()
inference_prog()
solved_new = True
while True:
time.sleep(60)
current_list = os.listdir(args.model_output_dir)
# logger.info("current_list is : {}".format(current_list))
# logger.info("model_file_list is : {}".format(model_file_list))
if set(model_file_list) == set(current_list):
if solved_new:
solved_new = False
logger.info("No New models created")
pass
else:
solved_new = True
increment_models = list()
for f in current_list:
if f not in model_file_list:
increment_models.append(f)
logger.info("increment_models is : {}".format(increment_models))
for model in increment_models:
model_dir = args.model_output_dir + "/" + model
if os.path.exists(model_dir + "/_success"):
logger.info("using models from " + model_dir)
with fluid.scope_guard(Scope):
fluid.io.load_persistables(
executor=exe, dirname=model_dir + "/")
inference_test(Scope, model_dir, args)
model_file_list = current_list
def infer_once(args):
# check models file has already been finished
if os.path.exists(args.model_output_dir + "/_success"):
logger.info("using models from " + args.model_output_dir)
exe = fluid.Executor(fluid.CPUPlace())
Scope = fluid.Scope()
inference_prog()
with fluid.scope_guard(Scope):
fluid.io.load_persistables(
executor=exe, dirname=args.model_output_dir + "/")
inference_test(Scope, args.model_output_dir, args)
if __name__ == '__main__':
args = parse_args()
# while setting infer_once please specify the dir to models file with --model_output_dir
if args.infer_once:
infer_once(args)
if args.infer_during_train:
infer_during_train(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
neural network for word2vec
"""
from __future__ import print_function
import math
import numpy as np
import paddle.fluid as fluid
def skip_gram_word2vec(dict_size,
word_frequencys,
embedding_size,
max_code_length=None,
with_hsigmoid=False,
with_nce=True,
is_sparse=False):
def nce_layer(input, label, embedding_size, num_total_classes,
num_neg_samples, sampler, word_frequencys, sample_weight):
w_param_name = "nce_w"
b_param_name = "nce_b"
w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, embedding_size],
dtype='float32',
name=w_param_name)
b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1], dtype='float32', name=b_param_name)
cost = fluid.layers.nce(input=input,
label=label,
num_total_classes=num_total_classes,
sampler=sampler,
custom_dist=word_frequencys,
sample_weight=sample_weight,
param_attr=fluid.ParamAttr(name=w_param_name),
bias_attr=fluid.ParamAttr(name=b_param_name),
num_neg_samples=num_neg_samples,
is_sparse=is_sparse)
return cost
def hsigmoid_layer(input, label, path_table, path_code, non_leaf_num,
is_sparse):
if non_leaf_num is None:
non_leaf_num = dict_size
cost = fluid.layers.hsigmoid(
input=input,
label=label,
num_classes=non_leaf_num,
path_table=path_table,
path_code=path_code,
is_custom=True,
is_sparse=is_sparse)
return cost
datas = []
input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
predict_word = fluid.layers.data(
name='predict_word', shape=[1], dtype='int64')
datas.append(input_word)
datas.append(predict_word)
if with_hsigmoid:
path_table = fluid.layers.data(
name='path_table',
shape=[max_code_length if max_code_length else 40],
dtype='int64')
path_code = fluid.layers.data(
name='path_code',
shape=[max_code_length if max_code_length else 40],
dtype='int64')
datas.append(path_table)
datas.append(path_code)
py_reader = fluid.layers.create_py_reader_by_data(
capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True)
words = fluid.layers.read_file(py_reader)
emb = fluid.layers.embedding(
input=words[0],
is_sparse=is_sparse,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='embeding',
initializer=fluid.initializer.Normal(scale=1 /
math.sqrt(dict_size))))
cost, cost_nce, cost_hs = None, None, None
if with_nce:
cost_nce = nce_layer(emb, words[1], embedding_size, dict_size, 5,
"uniform", word_frequencys, None)
cost = cost_nce
if with_hsigmoid:
cost_hs = hsigmoid_layer(emb, words[1], words[2], words[3], dict_size,
is_sparse)
cost = cost_hs
if with_nce and with_hsigmoid:
cost = fluid.layers.elementwise_add(cost_nce, cost_hs)
avg_cost = fluid.layers.reduce_mean(cost)
return avg_cost, py_reader
# -*- coding: utf-8 -*
import re
import argparse
def parse_args():
parser = argparse.ArgumentParser(
description="Paddle Fluid word2 vector preprocess")
parser.add_argument(
'--data_path',
type=str,
required=True,
help="The path of training dataset")
parser.add_argument(
'--dict_path',
type=str,
default='./dict',
help="The path of generated dict")
parser.add_argument(
'--freq',
type=int,
default=5,
help="If the word count is less then freq, it will be removed from dict")
parser.add_argument(
'--is_local',
action='store_true',
required=False,
default=False,
help='Local train or not, (default: False)')
return parser.parse_args()
def text_strip(text):
return re.sub("[^a-z ]", "", text)
def build_Huffman(word_count, max_code_length):
MAX_CODE_LENGTH = max_code_length
sorted_by_freq = sorted(word_count.items(), key=lambda x: x[1])
count = list()
vocab_size = len(word_count)
parent = [-1] * 2 * vocab_size
code = [-1] * MAX_CODE_LENGTH
point = [-1] * MAX_CODE_LENGTH
binary = [-1] * 2 * vocab_size
word_code_len = dict()
word_code = dict()
word_point = dict()
i = 0
for a in range(vocab_size):
count.append(word_count[sorted_by_freq[a][0]])
for a in range(vocab_size):
word_point[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH
word_code[sorted_by_freq[a][0]] = [-1] * MAX_CODE_LENGTH
for k in range(vocab_size):
count.append(1e15)
pos1 = vocab_size - 1
pos2 = vocab_size
min1i = 0
min2i = 0
b = 0
for r in range(vocab_size):
if pos1 >= 0:
if count[pos1] < count[pos2]:
min1i = pos1
pos1 = pos1 - 1
else:
min1i = pos2
pos2 = pos2 + 1
else:
min1i = pos2
pos2 = pos2 + 1
if pos1 >= 0:
if count[pos1] < count[pos2]:
min2i = pos1
pos1 = pos1 - 1
else:
min2i = pos2
pos2 = pos2 + 1
else:
min2i = pos2
pos2 = pos2 + 1
count[vocab_size + r] = count[min1i] + count[min2i]
#record the parent of left and right child
parent[min1i] = vocab_size + r
parent[min2i] = vocab_size + r
binary[min1i] = 0 #left branch has code 0
binary[min2i] = 1 #right branch has code 1
for a in range(vocab_size):
b = a
i = 0
while True:
code[i] = binary[b]
point[i] = b
i = i + 1
b = parent[b]
if b == vocab_size * 2 - 2:
break
word_code_len[sorted_by_freq[a][0]] = i
word_point[sorted_by_freq[a][0]][0] = vocab_size - 2
for k in range(i):
word_code[sorted_by_freq[a][0]][i - k - 1] = code[k]
# only non-leaf nodes will be count in
if point[k] - vocab_size >= 0:
word_point[sorted_by_freq[a][0]][i - k] = point[k] - vocab_size
return word_point, word_code, word_code_len
def preprocess(data_path, dict_path, freq, is_local):
"""
proprocess the data, generate dictionary and save into dict_path.
:param data_path: the input data path.
:param dict_path: the generated dict path. the data in dict is "word count"
:param freq:
:return:
"""
# word to count
word_count = dict()
if is_local:
for i in range(1, 100):
with open(data_path + "/news.en-000{:0>2d}-of-00100".format(
i)) as f:
for line in f:
line = line.lower()
line = text_strip(line)
words = line.split()
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[item] = 1
item_to_remove = []
for item in word_count:
if word_count[item] <= freq:
item_to_remove.append(item)
for item in item_to_remove:
del word_count[item]
path_table, path_code, word_code_len = build_Huffman(word_count, 40)
with open(dict_path, 'w+') as f:
for k, v in word_count.items():
f.write(str(k) + " " + str(v) + '\n')
with open(dict_path + "_ptable", 'w+') as f2:
for pk, pv in path_table.items():
f2.write(str(pk) + ":" + ' '.join((str(x) for x in pv)) + '\n')
with open(dict_path + "_pcode", 'w+') as f3:
for pck, pcv in path_table.items():
f3.write(str(pck) + ":" + ' '.join((str(x) for x in pcv)) + '\n')
if __name__ == "__main__":
args = parse_args()
preprocess(args.data_path, args.dict_path, args.freq, args.is_local)
# -*- coding: utf-8 -*
import numpy as np
import preprocess
import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
class Word2VecReader(object):
def __init__(self,
dict_path,
data_path,
filelist,
trainer_id,
trainer_num,
window_size=5):
self.window_size_ = window_size
self.data_path_ = data_path
self.filelist = filelist
self.num_non_leaf = 0
self.word_to_id_ = dict()
self.id_to_word = dict()
self.word_to_path = dict()
self.word_to_code = dict()
self.trainer_id = trainer_id
self.trainer_num = trainer_num
word_all_count = 0
word_counts = []
word_id = 0
with open(dict_path, 'r') as f:
for line in f:
word, count = line.split()[0], int(line.split()[1])
self.word_to_id_[word] = word_id
self.id_to_word[word_id] = word #build id to word dict
word_id += 1
word_counts.append(count)
word_all_count += count
with open(dict_path + "_word_to_id_", 'w+') as f6:
for k, v in self.word_to_id_.items():
f6.write(str(k) + " " + str(v) + '\n')
self.dict_size = len(self.word_to_id_)
self.word_frequencys = [
float(count) / word_all_count for count in word_counts
]
print("dict_size = " + str(
self.dict_size)) + " word_all_count = " + str(word_all_count)
with open(dict_path + "_ptable", 'r') as f2:
for line in f2:
self.word_to_path[line.split(":")[0]] = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')
self.num_non_leaf = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')[0]
print("word_ptable dict_size = " + str(len(self.word_to_path)))
with open(dict_path + "_pcode", 'r') as f3:
for line in f3:
self.word_to_code[line.split(":")[0]] = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')
print("word_pcode dict_size = " + str(len(self.word_to_code)))
def get_context_words(self, words, idx, window_size):
"""
Get the context word list of target word.
words: the words of the current line
idx: input word index
window_size: window size
"""
target_window = np.random.randint(1, window_size + 1)
# need to keep in mind that maybe there are no enough words before the target word.
start_point = idx - target_window if (idx - target_window) > 0 else 0
end_point = idx + target_window
# context words of the target word
targets = set(words[start_point:idx] + words[idx + 1:end_point + 1])
return list(targets)
def train(self, with_hs):
def _reader():
for file in self.filelist:
with open(self.data_path_ + "/" + file, 'r') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.text_strip(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx, self.window_size_)
for context_id in context_word_ids:
yield [target_id], [context_id]
else:
pass
count += 1
def _reader_hs():
for file in self.filelist:
with open(self.data_path_ + "/" + file, 'r') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.text_strip(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx, self.window_size_)
for context_id in context_word_ids:
yield [target_id], [context_id], [
self.word_to_code[self.id_to_word[
context_id]]
], [
self.word_to_path[self.id_to_word[
context_id]]
]
else:
pass
count += 1
if not with_hs:
return _reader
else:
return _reader_hs
if __name__ == "__main__":
window_size = 10
reader = Word2VecReader("data/enwik9_dict", "data/enwik9", window_size)
i = 0
for x, y in reader.train()():
print("x: " + str(x))
print("y: " + str(y))
print("\n")
if i == 10:
exit(0)
i += 1
from __future__ import print_function
import argparse
import logging
import os
import time
import numpy as np
# disable gpu training for this example
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
from paddle.fluid.executor import global_scope
import reader
from network_conf import skip_gram_word2vec
from infer import inference_test
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description="PaddlePaddle Word2vec example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled',
help="The path of training dataset")
parser.add_argument(
'--dict_path',
type=str,
default='./data/1-billion_dict',
help="The path of data dict")
parser.add_argument(
'--test_data_path',
type=str,
default='./data/text8',
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=100,
help="The size of mini-batch (default:100)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
parser.add_argument(
'--embedding_size',
type=int,
default=64,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--with_hs',
action='store_true',
required=False,
default=False,
help='using hierarchical sigmoid, (default: False)')
parser.add_argument(
'--with_nce',
action='store_true',
required=False,
default=False,
help='using negtive sampling, (default: True)')
parser.add_argument(
'--max_code_length',
type=int,
default=40,
help='max code length used by hierarchical sigmoid, (default: 40)')
parser.add_argument(
'--is_sparse',
action='store_true',
required=False,
default=False,
help='embedding and nce will use sparse or not, (default: False)')
parser.add_argument(
'--with_Adam',
action='store_true',
required=False,
default=False,
help='Using Adam as optimizer or not, (default: False)')
parser.add_argument(
'--is_local',
action='store_true',
required=False,
default=False,
help='Local train or not, (default: False)')
parser.add_argument(
'--with_speed',
action='store_true',
required=False,
default=False,
help='print speed or not , (default: False)')
parser.add_argument(
'--with_infer_test',
action='store_true',
required=False,
default=False,
help='Do inference every 100 batches , (default: False)')
parser.add_argument(
'--rank_num',
type=int,
default=4,
help="find rank_num-nearest result for test (default: 4)")
return parser.parse_args()
def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train((args.with_hs or (not args.with_nce))),
buf_size=args.batch_size * 100),
batch_size=args.batch_size)
py_reader.decorate_paddle_reader(train_reader)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exec_strategy = fluid.ExecutionStrategy()
print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
exec_strategy.num_threads = int(os.getenv("CPU_NUM"))
build_strategy = fluid.BuildStrategy()
if int(os.getenv("CPU_NUM")) > 1:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
train_exe = fluid.ParallelExecutor(
use_cuda=False,
loss_name=loss.name,
main_program=train_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
profile_state = "CPU"
profiler_step = 0
profiler_step_start = 20
profiler_step_end = 30
for pass_id in range(args.num_passes):
epoch_start = time.time()
py_reader.start()
batch_id = 0
start = time.clock()
try:
while True:
if profiler_step == profiler_step_start:
fluid.profiler.start_profiler(profile_state)
loss_val = train_exe.run(fetch_list=[loss.name])
loss_val = np.mean(loss_val)
if profiler_step == profiler_step_end:
fluid.profiler.stop_profiler('total', 'trainer_profile.log')
profiler_step += 1
else:
profiler_step += 1
if batch_id % 50 == 0:
logger.info(
"TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
format(pass_id, batch_id,
loss_val.mean() / args.batch_size,
py_reader.queue.size()))
if args.with_speed:
if batch_id % 1000 == 0 and batch_id != 0:
elapsed = (time.clock() - start)
start = time.clock()
samples = 1001 * args.batch_size * int(
os.getenv("CPU_NUM"))
logger.info("Time used: {}, Samples/Sec: {}".format(
elapsed, samples / elapsed))
# calculate infer result each 100 batches when using --with_infer_test
if args.with_infer_test:
if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(
batch_id)
inference_test(global_scope(), model_dir, args)
if batch_id % 500000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(
batch_id)
fluid.io.save_persistables(executor=exe, dirname=model_dir)
with open(model_dir + "/_success", 'w+') as f:
f.write(str(batch_id))
batch_id += 1
except fluid.core.EOFException:
py_reader.reset()
epoch_end = time.time()
logger.info("Epoch: {0}, Train total expend: {1} ".format(
pass_id, epoch_end - epoch_start))
model_dir = args.model_output_dir + '/pass-' + str(pass_id)
if trainer_id == 0:
fluid.io.save_persistables(executor=exe, dirname=model_dir)
with open(model_dir + "/_success", 'w+') as f:
f.write(str(pass_id))
def GetFileList(data_path):
return os.listdir(data_path)
def train(args):
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
filelist = GetFileList(args.train_data_path)
word2vec_reader = None
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
word2vec_reader = reader.Word2VecReader(
args.dict_path, args.train_data_path, filelist, 0, 1)
else:
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
word2vec_reader = reader.Word2VecReader(args.dict_path,
args.train_data_path, filelist,
trainer_id, trainer_num)
logger.info("dict_size: {}".format(word2vec_reader.dict_size))
loss, py_reader = skip_gram_word2vec(
word2vec_reader.dict_size,
word2vec_reader.word_frequencys,
args.embedding_size,
args.max_code_length,
args.with_hs,
args.with_nce,
is_sparse=args.is_sparse)
optimizer = None
if args.with_Adam:
optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
else:
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
# do local training
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
logger.info("run local training")
main_program = fluid.default_main_program()
with open("local.main.proto", "w") as f:
f.write(str(main_program))
train_loop(args, main_program, word2vec_reader, py_reader, loss, 0)
# do distribute training
else:
logger.info("run dist training")
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
training_role = os.environ["PADDLE_TRAINING_ROLE"]
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=True)
if training_role == "PSERVER":
logger.info("run pserver")
prog = t.get_pserver_program(current_endpoint)
startup = t.get_startup_program(
current_endpoint, pserver_program=prog)
with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")),
"w") as f:
f.write(str(prog))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif training_role == "TRAINER":
logger.info("run trainer")
train_prog = t.get_trainer_program()
with open("trainer.main.proto.{}".format(trainer_id), "w") as f:
f.write(str(train_prog))
train_loop(args, train_prog, word2vec_reader, py_reader, loss,
trainer_id)
def env_declar():
print("******** Rename Cluster Env to PaddleFluid Env ********")
print("Content-Type: text/plain\n\n")
for key in os.environ.keys():
print("%30s %s \n" % (key, os.environ[key]))
if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[
"PADDLE_IS_LOCAL"] == "0":
os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"]
os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"]
os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"]
os.environ["PADDLE_TRAINERS"] = os.environ["PADDLE_TRAINERS_NUM"]
os.environ["PADDLE_CURRENT_IP"] = os.environ["POD_IP"]
os.environ["PADDLE_TRAINER_ID"] = os.environ["PADDLE_TRAINER_ID"]
# we set the thread number same as CPU number
os.environ["CPU_NUM"] = "12"
print("Content-Type: text/plain\n\n")
for key in os.environ.keys():
print("%30s %s \n" % (key, os.environ[key]))
print("****** Rename Cluster Env to PaddleFluid Env END ******")
if __name__ == '__main__':
args = parse_args()
if args.is_local:
pass
else:
env_declar()
train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册