提交 d5b4ffd9 编写于 作者: C chengmo

Merge branch 'develop' into 'tdm_infer'

# Conflicts:
#   fleet_rec/core/trainers/single_trainer.py
#   fleet_rec/core/trainers/transpiler_trainer.py
......@@ -46,6 +46,7 @@ class ClusterTrainer(TranspileTrainer):
else:
self.regist_context_processor(
'train_pass', self.dataloader_train)
self.regist_context_processor('infer_pass', self.infer)
self.regist_context_processor('terminal_pass', self.terminal)
def build_strategy(self):
......@@ -139,14 +140,15 @@ class ClusterTrainer(TranspileTrainer):
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 10 == 0 and batch_id != 0:
if batch_id % self.fetch_period == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
self.save(epoch, "train", is_fleet=True)
fleet.stop_worker()
context['status'] = 'terminal_pass'
context['status'] = 'infer_pass'
def dataset_train(self, context):
fleet.init_worker()
......@@ -162,10 +164,7 @@ class ClusterTrainer(TranspileTrainer):
print_period=self.fetch_period)
self.save(i, "train", is_fleet=True)
fleet.stop_worker()
context['status'] = 'terminal_pass'
def infer(self, context):
context['status'] = 'terminal_pass'
context['status'] = 'infer_pass'
def terminal(self, context):
for model in self.increment_models:
......
......@@ -115,58 +115,6 @@ class SingleTrainer(TranspileTrainer):
self.save(i, "train", is_fleet=False)
context['status'] = 'infer_pass'
def infer(self, context):
logger.info("Run in infer pass")
infer_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model.infer_net()
logger.info("End build infer net")
if self.model._infer_data_loader is None:
context['status'] = 'terminal_pass'
return
reader = self._get_dataloader("Evaluate")
logger.info("End Get data loader")
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_infer_results().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models:
logger.info(
"Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
batch_id = 0
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 2 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
context['status'] = 'terminal_pass'
def terminal(self, context):
for model in self.increment_models:
print("epoch :{}, dir: {}".format(model[0], model[1]))
......
......@@ -36,7 +36,7 @@ class TranspileTrainer(Trainer):
def processor_register(self):
print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")
def _get_dataloader(self, state):
def _get_dataloader(self, state="TRAIN"):
if state == "TRAIN":
dataloader = self.model._data_loader
namespace = "train.reader"
......@@ -60,7 +60,7 @@ class TranspileTrainer(Trainer):
dataloader.set_sample_generator(reader, batch_size)
return dataloader
def _get_dataset(self, state):
def _get_dataset(self, state="TRAIN"):
if state == "TRAIN":
inputs = self.model.get_inputs()
namespace = "train.reader"
......@@ -115,24 +115,23 @@ class TranspileTrainer(Trainer):
if not need_save(epoch_id, save_interval, False):
return
# print("save inference model is not supported now.")
# return
print("save inference model is not supported now.")
return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
if feed_varnames is None or fetch_varnames is None:
return
feed_varnames = envs.get_global_env(
"save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env(
"save.inference.fetch_varnames", None, namespace)
fetch_vars = [fluid.default_main_program().global_block().vars[varname]
for varname in fetch_varnames]
dirname = envs.get_global_env(
"save.inference.dirname", None, namespace)
fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None, namespace)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_inference_model(dirname, feed_varnames, fetch_vars)
fleet.save_inference_model(self._exe, dirname, feed_varnames, fetch_vars)
else:
fluid.io.save_inference_model(
dirname, feed_varnames, fetch_vars, self._exe)
......@@ -179,7 +178,53 @@ class TranspileTrainer(Trainer):
context['is_exit'] = True
def infer(self, context):
context['is_exit'] = True
infer_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model.infer_net()
if self.model._infer_data_loader is None:
context['status'] = 'terminal_pass'
return
reader = self._get_dataloader("Evaluate")
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_infer_results().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models:
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
batch_id = 0
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 2 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
context['status'] = 'terminal_pass'
def terminal(self, context):
print("clean up and exit")
......
......@@ -72,8 +72,11 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if cluster_envs is None:
cluster_envs = {}
engine_extras = get_engine_extras()
if "train.trainer.threads" in engine_extras and "CPU_NUM" in cluster_envs:
cluster_envs["CPU_NUM"] = engine_extras["train.trainer.threads"]
envs.set_runtime_environs(cluster_envs)
envs.set_runtime_environs(get_engine_extras())
envs.set_runtime_environs(engine_extras)
need_print = {}
for k, v in os.environ.items():
......
#! /bin/bash
# download train_data
mkdir raw_data
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ raw_data/
# preprocess data
python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/test_build_dict
python preprocess.py --filter_corpus --dict_path raw_data/test_build_dict --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001
mkdir thirdparty
mv raw_data/test_build_dict thirdparty/
mv raw_data/test_build_dict_word_to_id_ thirdparty/
python preprocess.py --data_resplit --input_corpus_dir=raw_data/convert_text8 --output_corpus_dir=train_data
# download test data
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
tar xzvf test_dir.tar -C raw_data
mv raw_data/data/test_dir test_data/
rm -rf raw_data
# -*- coding: utf-8 -*
import os
import random
import re
import six
import argparse
import io
import math
prog = re.compile("[^a-z ]", flags=0)
def parse_args():
parser = argparse.ArgumentParser(
description="Paddle Fluid word2 vector preprocess")
parser.add_argument(
'--build_dict_corpus_dir', type=str, help="The dir of corpus")
parser.add_argument(
'--input_corpus_dir', type=str, help="The dir of input corpus")
parser.add_argument(
'--output_corpus_dir', type=str, help="The dir of output corpus")
parser.add_argument(
'--dict_path',
type=str,
default='./dict',
help="The path of dictionary ")
parser.add_argument(
'--min_count',
type=int,
default=5,
help="If the word count is less then min_count, it will be removed from dict"
)
parser.add_argument(
'--file_nums',
type=int,
default=1024,
help="re-split input corpus file nums"
)
parser.add_argument(
'--downsample',
type=float,
default=0.001,
help="filter word by downsample")
parser.add_argument(
'--filter_corpus',
action='store_true',
default=False,
help='Filter corpus')
parser.add_argument(
'--build_dict',
action='store_true',
default=False,
help='Build dict from corpus')
parser.add_argument(
'--data_resplit',
action='store_true',
default=False,
help='re-split input corpus files')
return parser.parse_args()
def text_strip(text):
#English Preprocess Rule
return prog.sub("", text.lower())
# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
if _is_unicode(s):
return s
try:
return _to_unicode(s)
except UnicodeDecodeError:
res = _to_unicode(s, ignore_errors=True)
return res
def _is_unicode(s):
if six.PY2:
if isinstance(s, unicode):
return True
else:
if isinstance(s, str):
return True
return False
def _to_unicode(s, ignore_errors=False):
if _is_unicode(s):
return s
error_mode = "ignore" if ignore_errors else "strict"
return s.decode("utf-8", errors=error_mode)
def filter_corpus(args):
"""
filter corpus and convert id.
"""
word_count = dict()
word_to_id_ = dict()
word_all_count = 0
id_counts = []
word_id = 0
#read dict
with io.open(args.dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, count = line.split()[0], int(line.split()[1])
word_count[word] = count
word_to_id_[word] = word_id
word_id += 1
id_counts.append(count)
word_all_count += count
#write word2id file
print("write word2id file to : " + args.dict_path + "_word_to_id_")
with io.open(
args.dict_path + "_word_to_id_", 'w+', encoding='utf-8') as fid:
for k, v in word_to_id_.items():
fid.write(k + " " + str(v) + '\n')
#filter corpus and convert id
if not os.path.exists(args.output_corpus_dir):
os.makedirs(args.output_corpus_dir)
for file in os.listdir(args.input_corpus_dir):
with io.open(args.output_corpus_dir + '/convert_' + file + '.csv', "w") as wf:
with io.open(
args.input_corpus_dir + '/' + file, encoding='utf-8') as rf:
print(args.input_corpus_dir + '/' + file)
for line in rf:
signal = False
line = text_strip(line)
words = line.split()
write_line = ""
for item in words:
if item in word_count:
idx = word_to_id_[item]
else:
idx = word_to_id_[native_to_unicode('<UNK>')]
count_w = id_counts[idx]
corpus_size = word_all_count
keep_prob = (
math.sqrt(count_w /
(args.downsample * corpus_size)) + 1
) * (args.downsample * corpus_size) / count_w
r_value = random.random()
if r_value > keep_prob:
continue
write_line += str(idx)
write_line += ","
signal = True
if signal:
write_line = write_line[:-1] + "\n"
wf.write(_to_unicode(write_line))
def build_dict(args):
"""
proprocess the data, generate dictionary and save into dict_path.
:param corpus_dir: the input data dir.
:param dict_path: the generated dict path. the data in dict is "word count"
:param min_count:
:return:
"""
# word to count
word_count = dict()
for file in os.listdir(args.build_dict_corpus_dir):
with io.open(
args.build_dict_corpus_dir + "/" + file, encoding='utf-8') as f:
print("build dict : ", args.build_dict_corpus_dir + "/" + file)
for line in f:
line = text_strip(line)
words = line.split()
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[item] = 1
item_to_remove = []
for item in word_count:
if word_count[item] <= args.min_count:
item_to_remove.append(item)
unk_sum = 0
for item in item_to_remove:
unk_sum += word_count[item]
del word_count[item]
#sort by count
word_count[native_to_unicode('<UNK>')] = unk_sum
word_count = sorted(
word_count.items(), key=lambda word_count: -word_count[1])
with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
for k, v in word_count:
f.write(k + " " + str(v) + '\n')
def data_split(args):
raw_data_dir = args.input_corpus_dir
new_data_dir = args.output_corpus_dir
if not os.path.exists(new_data_dir):
os.mkdir(new_data_dir)
files = os.listdir(raw_data_dir)
print(files)
index = 0
contents = []
for file_ in files:
with open(os.path.join(raw_data_dir, file_), 'r') as f:
contents.extend(f.readlines())
num = int(args.file_nums)
lines_per_file = len(contents) / num
print("contents: ", str(len(contents)))
print("lines_per_file: ", str(lines_per_file))
for i in range(1, num+1):
with open(os.path.join(new_data_dir, "part_" + str(i)), 'w') as fout:
data = contents[(i-1)*lines_per_file:min(i*lines_per_file,len(contents))]
for line in data:
fout.write(line)
if __name__ == "__main__":
args = parse_args()
if args.build_dict:
build_dict(args)
elif args.filter_corpus:
filter_corpus(args)
elif args.data_resplit:
data_split(args)
else:
print(
"error command line, please choose --build_dict or --filter_corpus")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册