Commit 28b8943b authored by JiabinYang

add async reader and converter

Parent cff7e36b
......@@ -23,24 +23,34 @@ cd data && ./download.sh && cd ..
Preprocess the data to build a dictionary.
```bash
python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict
python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled --dict_path data/1-billion_dict --is_local
```
If you would like to use a supported third-party vocabulary, set --other_dict_path to the directory where the vocabulary you want to use is stored, and pass --with_other_dict to enable it.
If you want to use the async executor to speed up training, first create a directory named async_data and then run the following command:
```bash
python async_data_converter.py --train_data_path your_train_data_path --dict_path your_dict_path
```
If you want to use hierarchical softmax, add --with_hs. This step writes the data converted for async_executor into the async_data directory you just created under the current directory; if your dataset is large, the conversion may take a long time.
## Train
The command-line options for training can be listed with `python train.py -h`.
### Local Train:
With the parallel executor:
```bash
export CPU_NUM=1
python train.py \
--train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \
--dict_path data/1-billion_dict \
--with_hs --with_nce --is_local \
--with_nce --is_local \
2>&1 | tee train.log
```
With the async executor:
```bash
python async_train.py --train_data_path ./async_data/ \
--dict_path data/1-billion_dict --with_nce --with_hs \
--epochs 1 --thread_num 1 --is_sparse --batch_size 100 --is_local 2>&1 | tee async_trainer1.log
```
### Distributed Train
Launch a distributed training job locally with 2 trainers and 2 pservers. In the distributed setting, the training data is sharded by trainer id so that trainers never process overlapping data, which improves training efficiency (see the sketch below).
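As a rough illustration (a minimal sketch, not the project's code), the reader assigns whole lines to trainers by line count modulo the number of trainers, mirroring the `trainer_id == count % trainer_num` check in `reader.py`:
```python
# Simplified sketch of how training lines are sharded across trainers;
# it mirrors the `trainer_id == count % trainer_num` check in reader.py.
def lines_for_trainer(lines, trainer_id, trainer_num):
    count = 1
    for line in lines:
        if trainer_id == count % trainer_num:
            yield line
        count += 1

sample = ["line-%d" % i for i in range(6)]
print(list(lines_for_trainer(sample, 0, 2)))  # ['line-1', 'line-3', 'line-5']
print(list(lines_for_trainer(sample, 1, 2)))  # ['line-0', 'line-2', 'line-4']
```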
......
......@@ -34,20 +34,31 @@ python preprocess.py --data_path ./data/1-billion-word-language-modeling-benchma
If you would like to use a supported third-party vocabulary, set --other_dict_path to the directory where the vocabulary you want to use is stored, and pass the --with_other_dict flag to enable it.
If you want to use the async executor to speed up training, first create a directory called async_data and then run the following command:
```bash
python async_data_converter.py --train_data_path your_train_data_path --dict_path your_dict_path
```
If you want to use hierarchical softmax, add --with_hs. This step writes the data converted for async_executor into the async_data directory you just created under the current directory; if your dataset is large, the conversion may take a long time to finish.
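For reference, each converted sample is one plain-text line whose space-separated slots match `word2vec_with_hs.proto` (`target`, `context`, and, with `--with_hs`, `path_table` and `path_code`). The snippet below is a minimal sketch of the format written by `reader.Word2VecReader.async_train`; the ids, path, and code values are placeholders.
```python
# Minimal sketch of the async_executor sample format (placeholder values).
target_id, context_id = 13, 742     # hypothetical word ids
path, code = [3, 7, 21], [0, 1, 1]  # hypothetical Huffman path / code

# Without --with_hs: "<len> <target> <len> <context>"
line = "1 %d 1 %d" % (target_id, context_id)

# With --with_hs: the Huffman path and code slots are appended.
line_hs = "1 %d 1 %d %d %s %d %s" % (
    target_id, context_id,
    len(path), " ".join(str(p) for p in path),
    len(code), " ".join(str(c) for c in code))

print(line)     # 1 13 1 742
print(line_hs)  # 1 13 1 742 3 3 7 21 3 0 1 1
```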
## Train
The command line options for training can be listed by `python train.py -h`.
### Local Train:
We set CPU_NUM=1 by default for execution.
With the parallel executor:
```bash
export CPU_NUM=1 && \
export CPU_NUM=1
python train.py \
--train_data_path ./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled \
--dict_path data/1-billion_dict \
--with_hs --with_nce --is_local \
--with_nce --is_local \
2>&1 | tee train.log
```
With the async executor:
```bash
python async_train.py --train_data_path ./async_data/ \
--dict_path data/1-billion_dict --with_nce --with_hs \
--epochs 1 --thread_num 1 --is_sparse --batch_size 100 --is_local 2>&1 | tee async_trainer1.log
```
### Distributed Train
Run a distributed training job with 2 pservers and 2 trainers on a single machine. The training data is sharded by trainer id so that trainers do not process overlapping data; the required environment variables are sketched below.
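The training scripts read the cluster layout from environment variables (`PADDLE_TRAINING_ROLE`, `PADDLE_TRAINER_ID`, `PADDLE_TRAINERS`, `PADDLE_PSERVER_IPS`, `PADDLE_PSERVER_PORT`, `PADDLE_CURRENT_IP`, `PADDLE_IS_LOCAL`). The sketch below shows illustrative values for one trainer process of the 2 pserver / 2 trainer run; the addresses and port are placeholders, adjust them to your actual setup.
```python
import os

# Hypothetical environment for trainer 0 of a 2 pserver / 2 trainer job;
# train.py and async_train.py read these variables when --is_local is not set.
os.environ["PADDLE_IS_LOCAL"] = "0"
os.environ["PADDLE_TRAINING_ROLE"] = "TRAINER"  # "PSERVER" for a pserver process
os.environ["PADDLE_TRAINER_ID"] = "0"           # 0 or 1 in a 2-trainer job
os.environ["PADDLE_TRAINERS"] = "2"
os.environ["PADDLE_PSERVER_IPS"] = "127.0.0.1,127.0.0.1"  # placeholder addresses
os.environ["PADDLE_PSERVER_PORT"] = "6174"                # placeholder port
os.environ["PADDLE_CURRENT_IP"] = "127.0.0.1"   # used by pserver processes
```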
......
from __future__ import print_function
import argparse
import logging
import os
import time
import numpy as np
import paddle.fluid as fluid
import reader
def parse_args():
parser = argparse.ArgumentParser(
description="PaddlePaddle Word2vec example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled',
help="The path of taining dataset")
parser.add_argument(
'--dict_path',
type=str,
default='./data/1-billion_dict',
help="The path of data dict")
parser.add_argument(
'--with_hs',
action='store_true',
required=False,
default=False,
help='using hierarchical sigmoid, (default: False)')
return parser.parse_args()
def GetFileList(data_path):
return os.listdir(data_path)
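# Build a Word2VecReader over the full dataset (trainer 0 of 1) and dump
# samples in async_executor format into ./async_data/ via reader.async_train().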
def converter(args):
filelist = GetFileList(args.train_data_path)
word2vec_reader = reader.Word2VecReader(
args.dict_path, args.train_data_path, filelist, 0, 1)
word2vec_reader.async_train(args.with_hs)
if __name__ == "__main__":
args = parse_args()
converter(args)
from __future__ import print_function
import argparse
import logging
import os
import time
import numpy as np
# disable gpu training for this example
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
from paddle.fluid.executor import global_scope
import reader
from network_conf import skip_gram_word2vec
from infer import inference_test
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description="PaddlePaddle Word2vec example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled',
help="The path of training dataset")
parser.add_argument(
'--dict_path',
type=str,
default='./data/1-billion_dict',
help="The path of data dict")
parser.add_argument(
'--test_data_path',
type=str,
default='./data/text8',
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=100,
help="The size of mini-batch (default:100)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
parser.add_argument(
'--thread_num', type=int, default=1, help='training thread num')
parser.add_argument(
'--embedding_size',
type=int,
default=64,
help='the dimension of the word embedding (default: 64)')
parser.add_argument(
'--with_hs',
action='store_true',
required=False,
default=False,
help='using hierarchical sigmoid, (default: False)')
parser.add_argument(
'--with_nce',
action='store_true',
required=False,
default=False,
help='use negative sampling (default: False)')
parser.add_argument(
'--max_code_length',
type=int,
default=40,
help='max code length used by hierarchical sigmoid, (default: 40)')
parser.add_argument(
'--is_sparse',
action='store_true',
required=False,
default=False,
help='embedding and nce will use sparse or not, (default: False)')
parser.add_argument(
'--with_Adam',
action='store_true',
required=False,
default=False,
help='Using Adam as optimizer or not, (default: False)')
parser.add_argument(
'--is_local',
action='store_true',
required=False,
default=False,
help='Local train or not, (default: False)')
parser.add_argument(
'--with_speed',
action='store_true',
required=False,
default=False,
help='print speed or not , (default: False)')
parser.add_argument(
'--with_infer_test',
action='store_true',
required=False,
default=False,
help='Do inference every 1000 batches (default: False)')
parser.add_argument(
'--rank_num',
type=int,
default=4,
help="find rank_num-nearest result for test (default: 4)")
parser.add_argument(
'--use_pyreader',
required=False,
default=False,
help='Whether you want to use pyreader, (default: False)')
parser.add_argument(
'--epochs',
type=int,
required=False,
default=False,
help='training epochs')
return parser.parse_args()
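# Run every file in the training directory through fluid.AsyncExecutor for
# args.epochs epochs, report a rough per-epoch throughput, and save an
# inference model after each epoch.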
def async_train_loop(args, train_program, dataset, loss, thread_num):
logger.info("run async_train_loop")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
async_executor = fluid.AsyncExecutor(place)
files = [
"%s/%s" % (args.train_data_path, filename)
for filename in os.listdir(args.train_data_path)
]
logger.info("files:" + str(files))
filelist = files
print("filelist:" + str(filelist))
fout = open("main_program.prototxt", "w")
fout.write(str(fluid.default_main_program()))
fout.close()
for i in range(args.epochs):
epoch_start = time.time()
async_executor.run(train_program,
dataset,
filelist,
thread_num, [loss],
debug=False)
epoch_stop = time.time()
run_time = epoch_stop - epoch_start
lines = len(filelist) * 1000000.0  # rough throughput estimate: assumes ~1e6 lines per file
print("run epoch%d done, lines=%s, time=%d, sample/second=%s" %
(i + 1, lines, run_time, lines / run_time))
epoch_model = "word2vec_model/epoch" + str(i + 1)
fluid.io.save_inference_model(epoch_model, [], [loss], exe)
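# py_reader + ParallelExecutor training loop: feed shuffled batches from the
# reader, log the loss every 50 batches, and optionally profile, report speed,
# run inference tests, and checkpoint persistable variables.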
def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train((args.with_hs or (not args.with_nce))),
buf_size=args.batch_size * 100),
batch_size=args.batch_size)
py_reader.decorate_paddle_reader(train_reader)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exec_strategy = fluid.ExecutionStrategy()
print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
exec_strategy.num_threads = int(os.getenv("CPU_NUM"))
build_strategy = fluid.BuildStrategy()
if int(os.getenv("CPU_NUM")) > 1:
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
train_exe = fluid.ParallelExecutor(
use_cuda=False,
loss_name=loss.name,
main_program=train_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
profile_state = "CPU"
profiler_step = 0
profiler_step_start = 20
profiler_step_end = 30
for pass_id in range(args.num_passes):
epoch_start = time.time()
py_reader.start()
batch_id = 0
start = time.clock()
try:
while True:
if profiler_step == profiler_step_start:
fluid.profiler.start_profiler(profile_state)
loss_val = train_exe.run(fetch_list=[loss.name])
loss_val = np.mean(loss_val)
if profiler_step == profiler_step_end:
fluid.profiler.stop_profiler('total', 'trainer_profile.log')
profiler_step += 1
else:
profiler_step += 1
if batch_id % 50 == 0:
logger.info(
"TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
format(pass_id, batch_id,
loss_val.mean() / args.batch_size,
py_reader.queue.size()))
if args.with_speed:
if batch_id % 1000 == 0 and batch_id != 0:
elapsed = (time.clock() - start)
start = time.clock()
samples = 1001 * args.batch_size * int(
os.getenv("CPU_NUM"))
logger.info("Time used: {}, Samples/Sec: {}".format(
elapsed, samples / elapsed))
# calculate infer result every 1000 batches when using --with_infer_test
if args.with_infer_test:
if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(
batch_id)
inference_test(global_scope(), model_dir, args)
if batch_id % 500000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(
batch_id)
fluid.io.save_persistables(executor=exe, dirname=model_dir)
with open(model_dir + "/_success", 'w+') as f:
f.write(str(batch_id))
batch_id += 1
except fluid.core.EOFException:
py_reader.reset()
epoch_end = time.time()
logger.info("Epoch: {0}, Train total expend: {1} ".format(
pass_id, epoch_end - epoch_start))
model_dir = args.model_output_dir + '/pass-' + str(pass_id)
if trainer_id == 0:
fluid.io.save_persistables(executor=exe, dirname=model_dir)
with open(model_dir + "/_success", 'w+') as f:
f.write(str(pass_id))
def GetFileList(data_path):
return os.listdir(data_path)
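# Entry point for async-executor training: build the skip-gram network and the
# DataFeedDesc from word2vec_with_hs.proto, then run either a local
# AsyncExecutor loop or a distributed (transpiled) pserver/trainer job.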
def async_train(args):
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
if not args.is_local and os.environ["PADDLE_TRAINING_ROLE"] == "PSERVER":
filelist = []
else:
filelist = GetFileList(args.train_data_path)
word2vec_reader = reader.Word2VecReader(
args.dict_path, args.train_data_path, filelist, 0, 1)
loss, words, pyreader = skip_gram_word2vec(
word2vec_reader.dict_size,
word2vec_reader.word_frequencys,
args.embedding_size,
args.max_code_length,
args.with_hs,
args.with_nce,
is_sparse=args.is_sparse)
dataset = fluid.DataFeedDesc('word2vec_with_hs.proto')
dataset.set_batch_size(args.batch_size)
dataset.set_use_slots([w.name for w in words])
optimizer = fluid.optimizer.SGD(learning_rate=1e-4)  # same SGD learning rate as train()
optimizer.minimize(loss)
# do local training
if args.is_local:
logger.info("run local training")
main_program = fluid.default_main_program()
with open("local.main.proto", "w") as f:
f.write(str(main_program))
async_train_loop(args,
fluid.default_main_program(), dataset, loss,
args.thread_num)
# do distribute training
else:
logger.info("run dist training")
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
training_role = os.environ["PADDLE_TRAINING_ROLE"]
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=False)
if training_role == "PSERVER":
logger.info("run pserver")
prog = t.get_pserver_program(current_endpoint)
startup = t.get_startup_program(
current_endpoint, pserver_program=prog)
with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")),
"w") as f:
f.write(str(prog))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif training_role == "TRAINER":
logger.info("run trainer")
train_prog = t.get_trainer_program()
with open("trainer.main.proto.{}".format(trainer_id), "w") as f:
f.write(str(train_prog))
async_train_loop(args, train_prog, dataset, loss, args.thread_num)
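# Entry point for py_reader/ParallelExecutor training: build the reader and the
# skip-gram network, choose Adam or SGD, then run local or distributed
# (transpiled) training.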
def train(args):
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
filelist = GetFileList(args.train_data_path)
word2vec_reader = None
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
word2vec_reader = reader.Word2VecReader(
args.dict_path, args.train_data_path, filelist, 0, 1)
else:
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainer_num = int(os.environ["PADDLE_TRAINERS"])
word2vec_reader = reader.Word2VecReader(args.dict_path,
args.train_data_path, filelist,
trainer_id, trainer_num)
logger.info("dict_size: {}".format(word2vec_reader.dict_size))
loss, py_reader = skip_gram_word2vec(
word2vec_reader.dict_size,
word2vec_reader.word_frequencys,
args.embedding_size,
args.max_code_length,
args.with_hs,
args.with_nce,
is_sparse=args.is_sparse)
optimizer = None
if args.with_Adam:
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
else:
optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
optimizer.minimize(loss)
# do local training
if args.is_local or os.getenv("PADDLE_IS_LOCAL", "1") == "1":
logger.info("run local training")
main_program = fluid.default_main_program()
with open("local.main.proto", "w") as f:
f.write(str(main_program))
train_loop(args, main_program, word2vec_reader, py_reader, loss, 0)
# do distribute training
else:
logger.info("run dist training")
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
training_role = os.environ["PADDLE_TRAINING_ROLE"]
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
config = fluid.DistributeTranspilerConfig()
#config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=True)
if training_role == "PSERVER":
logger.info("run pserver")
prog = t.get_pserver_program(current_endpoint)
startup = t.get_startup_program(
current_endpoint, pserver_program=prog)
with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")),
"w") as f:
f.write(str(prog))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif training_role == "TRAINER":
logger.info("run trainer")
train_prog = t.get_trainer_program()
with open("trainer.main.proto.{}".format(trainer_id), "w") as f:
f.write(str(train_prog))
train_loop(args, train_prog, word2vec_reader, py_reader, loss,
trainer_id)
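# Map the cluster scheduler's environment variables (TRAINING_ROLE, PADDLE_PORT,
# PADDLE_PSERVERS, ...) onto the PADDLE_* names that this script reads.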
def env_declar():
print("******** Rename Cluster Env to PaddleFluid Env ********")
'''
print("Content-Type: text/plain\n\n")
for key in os.environ.keys():
print("%30s %s \n" % (key, os.environ[key]))
'''
if os.environ["TRAINING_ROLE"] == "PSERVER" or os.environ[
"PADDLE_IS_LOCAL"] == "0":
os.environ["PADDLE_TRAINING_ROLE"] = os.environ["TRAINING_ROLE"]
os.environ["PADDLE_PSERVER_PORT"] = os.environ["PADDLE_PORT"]
os.environ["PADDLE_PSERVER_IPS"] = os.environ["PADDLE_PSERVERS"]
os.environ["PADDLE_TRAINERS"] = os.environ["PADDLE_TRAINERS_NUM"]
os.environ["PADDLE_CURRENT_IP"] = os.environ["POD_IP"]
os.environ["PADDLE_TRAINER_ID"] = os.environ["PADDLE_TRAINER_ID"]
# we set the thread number same as CPU number
os.environ["CPU_NUM"] = "12"
'''
print("Content-Type: text/plain\n\n")
for key in os.environ.keys():
print("%30s %s \n" % (key, os.environ[key]))
print("****** Rename Cluster Env to PaddleFluid Env END ******")
'''
if __name__ == '__main__':
args = parse_args()
logger.info(args)
if args.is_local:
pass
else:
#env_declar()
pass
#train(args)
async_train(args)
......@@ -110,29 +110,29 @@ def build_test_case_from_file(args, emb):
def build_small_test_case(emb):
emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[
'aunt']]
desc1 = "boy - girl + aunt = uncle"
label1 = word_to_id["uncle"]
emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[
word_to_id['sisters']]
desc2 = "brother - sister + sisters = brothers"
label2 = word_to_id["brothers"]
emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[
'woman']]
desc3 = "king - queen + woman = man"
label3 = word_to_id["man"]
emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[
word_to_id['slowly']]
desc4 = "reluctant - reluctantly + slowly = slow"
label4 = word_to_id["slow"]
emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[
'deeper']]
desc5 = "old - older + deeper = deep"
label5 = word_to_id["deep"]
# emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[
# 'aunt']]
# desc1 = "boy - girl + aunt = uncle"
# label1 = word_to_id["uncle"]
# emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[
# word_to_id['sisters']]
# desc2 = "brother - sister + sisters = brothers"
# label2 = word_to_id["brothers"]
# emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[
# 'woman']]
# desc3 = "king - queen + woman = man"
# label3 = word_to_id["man"]
# emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[
# word_to_id['slowly']]
# desc4 = "reluctant - reluctantly + slowly = slow"
# label4 = word_to_id["slow"]
# emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[
# 'deeper']]
# desc5 = "old - older + deeper = deep"
# label5 = word_to_id["deep"]
emb6 = emb[word_to_id['boy']]
desc6 = "boy"
emb6 = emb[word_to_id['father']]
desc6 = "father"
label6 = word_to_id["boy"]
emb7 = emb[word_to_id['king']]
desc7 = "king"
......@@ -143,13 +143,16 @@ def build_small_test_case(emb):
emb9 = emb[word_to_id['key']]
desc9 = "key"
label9 = word_to_id["key"]
test_cases = [emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8, emb9]
test_case_desc = [
desc1, desc2, desc3, desc4, desc5, desc6, desc7, desc8, desc9
]
test_labels = [
label1, label2, label3, label4, label5, label6, label7, label8, label9
]
# test_cases = [emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8, emb9]
# test_case_desc = [
# desc1, desc2, desc3, desc4, desc5, desc6, desc7, desc8, desc9
# ]
# test_labels = [
# label1, label2, label3, label4, label5, label6, label7, label8, label9
# ]
test_cases = [emb6, emb7, emb8, emb9]
test_case_desc = [desc6, desc7, desc8, desc9]
test_labels = [label6, label7, label8, label9]
return norm(np.array(test_cases)), test_case_desc, test_labels
......
......@@ -29,7 +29,8 @@ def skip_gram_word2vec(dict_size,
max_code_length=None,
with_hsigmoid=False,
with_nce=True,
is_sparse=False):
is_sparse=False,
use_pyreader=False):
def nce_layer(input, label, embedding_size, num_total_classes,
num_neg_samples, sampler, word_frequencys, sample_weight):
......@@ -73,9 +74,8 @@ def skip_gram_word2vec(dict_size,
datas = []
input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
predict_word = fluid.layers.data(
name='predict_word', shape=[1], dtype='int64')
input_word = fluid.layers.data(name="target", shape=[1], dtype='int64')
predict_word = fluid.layers.data(name='context', shape=[1], dtype='int64')
datas.append(input_word)
datas.append(predict_word)
......@@ -91,10 +91,19 @@ def skip_gram_word2vec(dict_size,
datas.append(path_table)
datas.append(path_code)
py_reader = fluid.layers.create_py_reader_by_data(
capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True)
py_reader = None
words = None
if use_pyreader:
py_reader = fluid.layers.create_py_reader_by_data(
capacity=64,
feed_list=datas,
name='py_reader',
use_double_buffer=True)
words = fluid.layers.read_file(py_reader)
else:
words = datas
words = fluid.layers.read_file(py_reader)
target_emb = fluid.layers.embedding(
input=words[0],
is_sparse=is_sparse,
......@@ -126,4 +135,4 @@ def skip_gram_word2vec(dict_size,
avg_cost = fluid.layers.reduce_mean(cost)
return avg_cost, py_reader
return avg_cost, words, py_reader
......@@ -203,26 +203,45 @@ def preprocess(args):
for line in f:
word_count[native_to_unicode(line.strip())] = 1
# if args.is_local:
# for i in range(1, 100):
# with io.open(
# args.data_path + "/news.en-000{:0>2d}-of-00100".format(i),
# encoding='utf-8') as f:
# for line in f:
# line = strip_lines(line)
# words = line.split()
# if args.with_other_dict:
# for item in words:
# if item in word_count:
# word_count[item] = word_count[item] + 1
# else:
# word_count[native_to_unicode('<UNK>')] += 1
# else:
# for item in words:
# if item in word_count:
# word_count[item] = word_count[item] + 1
# else:
# word_count[item] = 1
if args.is_local:
for i in range(1, 100):
with io.open(
args.data_path + "/news.en-000{:0>2d}-of-00100".format(i),
encoding='utf-8') as f:
for line in f:
with io.open(args.data_path + "/text8", encoding='utf-8') as f:
for line in f:
if args.with_other_dict:
line = strip_lines(line)
words = line.split()
if args.with_other_dict:
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[native_to_unicode('<UNK>')] += 1
else:
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[item] = 1
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[native_to_unicode('<UNK>')] += 1
else:
line = text_strip(line)
words = line.split()
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[item] = 1
item_to_remove = []
for item in word_count:
if word_count[item] <= args.freq:
......@@ -230,6 +249,7 @@ def preprocess(args):
for item in item_to_remove:
del word_count[item]
print(word_count)
path_table, path_code, word_code_len = build_Huffman(word_count, 40)
with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
......
......@@ -165,6 +165,101 @@ class Word2VecReader(object):
else:
return _reader_hs
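# Convert the raw training files into async_executor input: samples are written
# round-robin across 20 shard files under ./async_data/, one
# "<len> <target_id> <len> <context_id>" line per (target, context) pair, with
# the Huffman path and code slots appended when with_hs is set.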
def async_train(self, with_hs):
def _reader():
write_f = list()
for i in range(20):
write_f.append(
io.open(
"./async_data/async_" + str(i), 'w+', encoding='utf-8'))
for file in self.filelist:
with io.open(
self.data_path_ + "/" + file, 'r',
encoding='utf-8') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
file_spilt_count = 0
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.strip_lines(line, self.word_count)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx)
for context_id in context_word_ids:
content = "1" + " " + str(
target_id) + " " + "1" + " " + str(
context_id) + '\n'
write_f[file_spilt_count %
20].write(content.decode('utf-8'))
file_spilt_count += 1
else:
pass
count += 1
for i in range(20):
write_f[i].close()
def _reader_hs():
write_f = list()
for i in range(20):
write_f.append(
io.open(
"./async_data/async_" + str(i), 'w+', encoding='utf-8'))
for file in self.filelist:
with io.open(
self.data_path_ + "/" + file, 'r',
encoding='utf-8') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
file_spilt_count = 0
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.strip_lines(line, self.word_count)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
]
for idx, target_id in enumerate(word_ids):
context_word_ids = self.get_context_words(
word_ids, idx)
for context_id in context_word_ids:
path = [
str(i)
for i in self.word_to_path[
self.id_to_word[target_id]]
]
code = [
str(j)
for j in self.word_to_code[
self.id_to_word[target_id]]
]
content = str(1) + " " + str(
target_id
) + " " + str(1) + " " + str(
context_id
) + " " + str(len(path)) + " " + ' '.join(
path) + " " + str(len(
code)) + " " + ' '.join(code) + '\n'
write_f[file_spilt_count %
20].write(content.decode('utf-8'))
file_spilt_count += 1
else:
pass
count += 1
for i in range(20):
write_f[i].close()
if not with_hs:
_reader()
else:
_reader_hs()
if __name__ == "__main__":
window_size = 5
......
......@@ -196,7 +196,7 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
time.sleep(10)
epoch_start = time.time()
batch_id = 0
start = time.clock()
start = time.time()
try:
while True:
......@@ -211,8 +211,8 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
loss_val.mean(), py_reader.queue.size()))
if args.with_speed:
if batch_id % 1000 == 0 and batch_id != 0:
elapsed = (time.clock() - start)
start = time.clock()
elapsed = (time.time() - start)
start = time.time()
samples = 1001 * args.batch_size * int(
os.getenv("CPU_NUM"))
logger.info("Time used: {}, Samples/Sec: {}".format(
......@@ -261,20 +261,21 @@ def train(args):
args.dict_path, args.train_data_path, filelist, 0, 1)
else:
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
trainer_num = int(os.environ["PADDLE_TRAINERS"])
word2vec_reader = reader.Word2VecReader(args.dict_path,
args.train_data_path, filelist,
trainer_id, trainer_num)
logger.info("dict_size: {}".format(word2vec_reader.dict_size))
loss, py_reader = skip_gram_word2vec(
loss, words, py_reader = skip_gram_word2vec(
word2vec_reader.dict_size,
word2vec_reader.word_frequencys,
args.embedding_size,
args.max_code_length,
args.with_hs,
args.with_nce,
is_sparse=args.is_sparse)
is_sparse=args.is_sparse,
use_pyreader=True)
optimizer = None
if args.with_Adam:
......
name : "MultiSlotDataFeed" batch_size : 32 multi_slot_desc {
slots {
name:
"target" type : "uint64" is_dense : false is_used : false
}
slots {
name:
"context" type : "uint64" is_dense : false is_used : false
}
slots {
name:
"path_table" type : "uint64" is_dense : false is_used : false
}
slots {
name:
"path_code" type : "uint64" is_dense : false is_used : false
}
}
\ No newline at end of file