diff --git a/fluid/face_detection/.gitignore b/fluid/face_detection/.gitignore
index ea3e7b052591ddb7d19525a685c13971bededf6f..0636bd5b2995e0a0fa27fe54be6ccbbb78074dca 100644
--- a/fluid/face_detection/.gitignore
+++ b/fluid/face_detection/.gitignore
@@ -9,3 +9,4 @@ log*
 output*
 pred
 eval_tools
+box*
diff --git a/fluid/image_classification/caffe2fluid/kaffe/graph.py b/fluid/image_classification/caffe2fluid/kaffe/graph.py
index baea3cc1dc9431d07d0d3ca7191a429d1ef0f398..b342ce4dcb3dda6e38597457839ba376c9a27354 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/graph.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/graph.py
@@ -122,8 +122,8 @@ class Graph(object):
     def compute_output_shapes(self):
         sorted_nodes = self.topologically_sorted()
         for node in sorted_nodes:
-            node.output_shape = make_tensor(
-                *NodeKind.compute_output_shape(node))
+            node.output_shape = make_tensor(*NodeKind.compute_output_shape(
+                node))
 
     def replaced(self, new_nodes):
         return Graph(nodes=new_nodes, name=self.name, trace=self.output_trace)
diff --git a/fluid/mnist/.run_ce.sh b/fluid/mnist/.run_ce.sh
new file mode 100755
index 0000000000000000000000000000000000000000..d6ccf429b52da1ff26ac02df5af287461a823a98
--- /dev/null
+++ b/fluid/mnist/.run_ce.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+# This file is only used for continuous evaluation.
+
+rm -rf *_factor.txt
+model_file='model.py'
+python $model_file --batch_size 128 --pass_num 5 --device CPU | python _ce.py
diff --git a/fluid/mnist/_ce.py b/fluid/mnist/_ce.py
new file mode 100644
index 0000000000000000000000000000000000000000..8767e5692ea02d6e181cde714d4c53feba05e3ab
--- /dev/null
+++ b/fluid/mnist/_ce.py
@@ -0,0 +1,61 @@
+# this file is only used for continuous evaluation test!
+
+import os
+import sys
+sys.path.append(os.environ['ceroot'])
+from kpi import CostKpi, DurationKpi, AccKpi
+
+# NOTE kpi.py should shared in models in some way!!!!
+
+train_cost_kpi = CostKpi('train_cost', 0.02, actived=True)
+test_acc_kpi = AccKpi('test_acc', 0.005, actived=True)
+train_duration_kpi = DurationKpi('train_duration', 0.06, actived=True)
+train_acc_kpi = AccKpi('train_acc', 0.005, actived=True)
+
+tracking_kpis = [
+    train_acc_kpi,
+    train_cost_kpi,
+    test_acc_kpi,
+    train_duration_kpi,
+]
+
+def parse_log(log):
+    '''
+    This method should be implemented by model developers.
+
+    The suggestion:
+
+    each line accepted by this parser should contain "kpis", the KPI name
+    and its value, separated by whitespace, for example:
+
+    "
+    kpis train_cost 1.0
+    kpis test_acc 0.99
+    kpis train_duration 12.5
+    "
+    '''
+    for line in log.split('\n'):
+        # split on any whitespace so both tab- and space-separated kpi lines work
+        fs = line.strip().split()
+        print (fs)
+        if len(fs) == 3 and fs[0] == 'kpis':
+            kpi_name = fs[1]
+            kpi_value = float(fs[2])
+            yield kpi_name, kpi_value
+
+
+def log_to_ce(log):
+    kpi_tracker = {}
+    for kpi in tracking_kpis:
+        kpi_tracker[kpi.name] = kpi
+
+    for (kpi_name, kpi_value) in parse_log(log):
+        print (kpi_name, kpi_value)
+        kpi_tracker[kpi_name].add_record(kpi_value)
+        kpi_tracker[kpi_name].persist()
+
+
+if __name__ == '__main__':
+    log = sys.stdin.read()
+    log_to_ce(log)
+
diff --git a/fluid/mnist/model.py b/fluid/mnist/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa470e86d2f555a7cf778793de4cf47fc4841868
--- /dev/null
+++ b/fluid/mnist/model.py
@@ -0,0 +1,200 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import cProfile  # needed by run_benchmark when --use_cprof is set
+import time
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+
+SEED = 90
+DTYPE = "float32"
+
+# random seed must set before configuring the network.
+fluid.default_startup_program().random_seed = SEED
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("mnist model benchmark.")
+    parser.add_argument(
+        '--batch_size', type=int, default=128, help='The minibatch size.')
+    parser.add_argument(
+        '--iterations', type=int, default=35, help='The number of minibatches.')
+    parser.add_argument(
+        '--pass_num', type=int, default=5, help='The number of passes.')
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='GPU',
+        choices=['CPU', 'GPU'],
+        help='The device type.')
+    parser.add_argument(
+        '--infer_only', action='store_true', help='If set, run forward only.')
+    parser.add_argument(
+        '--use_cprof', action='store_true', help='If set, use cProfile.')
+    parser.add_argument(
+        '--use_nvprof',
+        action='store_true',
+        help='If set, use nvprof for CUDA.')
+    args = parser.parse_args()
+    return args
+
+
+def print_arguments(args):
+    vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
+                                vars(args)['device'] == 'GPU')
+    print('----------- Configuration Arguments -----------')
+    for arg, value in sorted(vars(args).iteritems()):
+        print('%s: %s' % (arg, value))
+    print('------------------------------------------------')
+
+
+def cnn_model(data):
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=data,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+
+    # TODO(dzhwinter) : refine the initializer and random seed setting
+    SIZE = 10
+    input_shape = conv_pool_2.shape
+    param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
+    scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
+
+    predict = fluid.layers.fc(
+        input=conv_pool_2,
+        size=SIZE,
+        act="softmax",
+        param_attr=fluid.param_attr.ParamAttr(
+            initializer=fluid.initializer.NormalInitializer(
+                loc=0.0, scale=scale)))
+    return predict
+
+
+def eval_test(exe, batch_acc, batch_size_tensor, inference_program):
+    test_reader = paddle.batch(
+        paddle.dataset.mnist.test(), batch_size=args.batch_size)
+    test_pass_acc = fluid.average.WeightedAverage()
+    for batch_id, data in enumerate(test_reader()):
+        img_data = np.array(map(lambda x:
x[0].reshape([1, 28, 28]), + data)).astype(DTYPE) + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = y_data.reshape([len(y_data), 1]) + + acc, weight = exe.run(inference_program, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[batch_acc, batch_size_tensor]) + test_pass_acc.add(value=acc, weight=weight) + pass_acc = test_pass_acc.eval() + return pass_acc + + +def run_benchmark(model, args): + if args.use_cprof: + pr = cProfile.Profile() + pr.enable() + start_time = time.time() + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + # inference program + inference_program = fluid.default_main_program().clone() + with fluid.program_guard(inference_program): + inference_program = fluid.io.get_inference_program( + target_vars=[batch_acc, batch_size_tensor]) + + # Optimization + opt = fluid.optimizer.AdamOptimizer( + learning_rate=0.001, beta1=0.9, beta2=0.999) + opt.minimize(avg_cost) + + fluid.memory_optimize(fluid.default_main_program()) + + # Initialize executor + place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + # Parameter initialization + exe.run(fluid.default_startup_program()) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=args.batch_size) + + accuracy = fluid.average.WeightedAverage() + for pass_id in range(args.pass_num): + accuracy.reset() + pass_start = time.time() + every_pass_loss = [] + for batch_id, data in enumerate(train_reader()): + img_data = np.array( + map(lambda x: x[0].reshape([1, 28, 28]), data)).astype(DTYPE) + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = y_data.reshape([len(y_data), 1]) + + start = time.time() + loss, acc, weight = exe.run( + fluid.default_main_program(), + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[avg_cost, batch_acc, batch_size_tensor] + ) # The accuracy is the accumulation of batches, but not the current batch. + end = time.time() + accuracy.add(value=acc, weight=weight) + every_pass_loss.append(loss) + print("Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" % + (pass_id, batch_id, loss, acc)) + + pass_end = time.time() + + train_avg_acc = accuracy.eval() + train_avg_loss = np.mean(every_pass_loss) + test_avg_acc = eval_test(exe, batch_acc, batch_size_tensor, + inference_program) + + print( + "pass=%d, train_avg_acc=%f,train_avg_loss=%f, test_avg_acc=%f, elapse=%f" + % (pass_id, train_avg_acc, train_avg_loss, test_avg_acc, + (pass_end - pass_start))) + #Note: The following logs are special for CE monitoring. + #Other situations do not need to care about these logs. 
+ print ("kpis train_acc %f" % train_avg_acc) + print ("kpis train_cost %f" % train_avg_loss) + print ("kpis test_acc %f" % test_avg_acc) + print ("kpis train_duration %f" % (pass_end - pass_start)) + + +if __name__ == '__main__': + args = parse_args() + print_arguments(args) + if args.use_nvprof and args.device == 'GPU': + with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof: + run_benchmark(cnn_model, args) + else: + run_benchmark(cnn_model, args) diff --git a/fluid/neural_machine_translation/transformer/README_cn.md b/fluid/neural_machine_translation/transformer/README_cn.md index 61d1295f7e197c0477c06e1353629a3155775128..561c5c30debc60a07050a2988bde8a70f9bc3bb5 100644 --- a/fluid/neural_machine_translation/transformer/README_cn.md +++ b/fluid/neural_machine_translation/transformer/README_cn.md @@ -9,13 +9,14 @@ ```text . ├── images # README 文档中的图片 -├── optim.py # learning rate scheduling 计算程序 +├── config.py # 训练、预测以及模型参数配置 ├── infer.py # 预测脚本 ├── model.py # 模型定义 +├── optim.py # learning rate scheduling 计算程序 ├── reader.py # 数据读取接口 ├── README.md # 文档 ├── train.py # 训练脚本 -└── config.py # 训练、预测以及模型参数配置 +└── util.py # wordpiece 数据解码工具 ``` ### 简介 @@ -58,34 +59,43 @@ Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 la ### 数据准备 -我们以 [WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例,同时参照论文中的设置使用 BPE(byte-pair encoding)[4]编码的数据,使用这种方式表示的数据能够更好的解决未登录词(out-of-vocabulary,OOV)的问题。用到的 BPE 数据可以参照[这里](https://github.com/google/seq2seq/blob/master/docs/data.md)进行下载,下载后解压,其中 `train.tok.clean.bpe.32000.en` 和 `train.tok.clean.bpe.32000.de` 为使用 BPE 的训练数据(平行语料,分别对应了英语和德语,经过了 tokenize 和 BPE 的处理),`newstest2013.tok.bpe.32000.en` 和 `newstest2013.tok.bpe.32000.de` 等为测试数据(`newstest2013.tok.en` 和 `newstest2013.tok.de` 等则为对应的未使用 BPE 的测试数据),`vocab.bpe.32000` 为相应的词典文件(源语言和目标语言共享该词典文件)。 +WMT 数据集是机器翻译领域公认的主流数据集;WMT 英德和英法数据集也是 Transformer 论文中所用数据集,其中英德数据集使用了 BPE(byte-pair encoding)[4]编码的数据,英法数据集使用了 wordpiece [5]的数据。我们这里也将使用 WMT 英德和英法翻译数据,并和论文保持一致使用 BPE 和 wordpiece 的数据,下面给出了使用的方法。对于其他自定义数据,参照下文遵循或转换为类似的数据格式即可。 + +#### WMT 英德翻译数据 -由于本示例中的数据读取脚本 `reader.py` 使用的样本数据的格式为 `\t` 分隔的的源语言和目标语言句子对(句子中的词之间使用空格分隔), 因此需要将源语言到目标语言的平行语料库文件合并为一个文件,可以执行以下命令进行合并: +[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)是一个中等规模的数据集。参照论文,英德数据集我们使用 BPE 编码的数据,这能够更好的解决未登录词(out-of-vocabulary,OOV)的问题[4]。用到的 BPE 数据可以参照[这里](https://github.com/google/seq2seq/blob/master/docs/data.md)进行下载(如果希望在自定义数据中使用 BPE 编码,可以参照[这里](https://github.com/rsennrich/subword-nmt)进行预处理),下载后解压,其中 `train.tok.clean.bpe.32000.en` 和 `train.tok.clean.bpe.32000.de` 为使用 BPE 的训练数据(平行语料,分别对应了英语和德语,经过了 tokenize 和 BPE 的处理),`newstest2013.tok.bpe.32000.en` 和 `newstest2013.tok.bpe.32000.de` 等为测试数据(`newstest2013.tok.en` 和 `newstest2013.tok.de` 等则为对应的未使用 BPE 的测试数据),`vocab.bpe.32000` 为相应的词典文件(源语言和目标语言共享该词典文件)。 + +由于本示例中的数据读取脚本 `reader.py` 默认使用的样本数据的格式为 `\t` 分隔的的源语言和目标语言句子对(默认句子中的词之间使用空格分隔),因此需要将源语言到目标语言的平行语料库文件合并为一个文件,可以执行以下命令进行合并: ```sh paste -d '\t' train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.tok.clean.bpe.32000.en-de ``` -此外,下载的词典文件 `vocab.bpe.32000` 中未包含表示序列开始、序列结束和未登录词的特殊符号,可以使用如下命令在词典中加入 `` 、`` 和 `` 作为这三个特殊符号。 +此外,下载的词典文件 `vocab.bpe.32000` 中未包含表示序列开始、序列结束和未登录词的特殊符号,可以使用如下命令在词典中加入 `` 、`` 和 `` 作为这三个特殊符号(用 BPE 表示数据已有效避免了未登录词的问题,这里加入只是做通用处理)。 ```sh sed -i '1i\\n\n' vocab.bpe.32000 ``` -对于其他自定义数据,遵循或转换为上述的数据格式即可。如果希望在自定义数据中使用 BPE 编码,可以参照[这里](https://github.com/rsennrich/subword-nmt)进行预处理。 +#### WMT 英法翻译数据 + +[WMT'14 EN-FR 数据集](http://www.statmt.org/wmt14/translation-task.html)是一个较大规模的数据集。参照论文,英法数据我们使用 wordpiece 
表示的数据,wordpiece 和 BPE 类似同为采用 sub-word units 来解决 OOV 问题的方法[5]。我们提供了已完成预处理的 wordpiece 数据的下载,可以从[这里](http://transformer-data.bj.bcebos.com/wmt14_enfr.tar)下载,其中 `train.wordpiece.en-fr` 为使用 wordpiece 的训练数据,`newstest2014.wordpiece.en-fr` 为测试数据(`newstest2014.tok.en` 和 `newstest2014.tok.fr` 为对应的未经 wordpiece 处理过的测试数据,使用[脚本](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl)进行了 tokenize 的处理),`vocab.wordpiece.en-fr` 为相应的词典文件(源语言和目标语言共享该词典文件)。 + +提供的英法翻译数据无需进行额外的处理,可以直接使用;需要注意的是,这些用 wordpiece 表示的数据中句子内的 token 之间使用 `\x01` 而非空格进行分隔(因部分 token 内包含空格),这需要在训练时进行指定。 ### 模型训练 -`train.py` 是模型训练脚本,可以执行以下命令进行模型训练: +`train.py` 是模型训练脚本。以英德翻译数据为例,可以执行以下命令进行模型训练: ```sh python -u train.py \ --src_vocab_fpath data/vocab.bpe.32000 \ --trg_vocab_fpath data/vocab.bpe.32000 \ --special_token '' '' '' \ --train_file_pattern data/train.tok.clean.bpe.32000.en-de \ + --token_delimiter ' ' \ --use_token_batch True \ --batch_size 3200 \ --sort_type pool \ - --pool_size 200000 \ + --pool_size 200000 ``` -上述命令中设置了源语言词典文件路径(`src_vocab_fpath`)、目标语言词典文件路径(`trg_vocab_fpath`)、训练数据文件(`train_file_pattern`,支持通配符)等数据相关的参数和构造 batch 方式(`use_token_batch` 指出数据按照 token 数目或者 sequence 数目组成 batch)等 reader 相关的参数。有关这些参数更详细的信息可以通过执行以下命令查看: +上述命令中设置了源语言词典文件路径(`src_vocab_fpath`)、目标语言词典文件路径(`trg_vocab_fpath`)、训练数据文件(`train_file_pattern`,支持通配符)等数据相关的参数和构造 batch 方式(`use_token_batch` 指定了数据按照 token 数目或者 sequence 数目组成 batch)等 reader 相关的参数。有关这些参数更详细的信息可以通过执行以下命令查看: ```sh python train.py --help ``` @@ -98,19 +108,20 @@ python -u train.py \ --trg_vocab_fpath data/vocab.bpe.32000 \ --special_token '' '' '' \ --train_file_pattern data/train.tok.clean.bpe.32000.en-de \ + --token_delimiter ' ' \ --use_token_batch True \ --batch_size 3200 \ --sort_type pool \ --pool_size 200000 \ - n_layer 8 \ + n_layer 6 \ n_head 16 \ d_model 1024 \ d_inner_hid 4096 \ dropout 0.3 ``` -有关这些参数更详细信息的还请参考 `config.py` 中的注释说明。 +有关这些参数更详细信息的请参考 `config.py` 中的注释说明。对于英法翻译数据,执行训练和英德翻译训练类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外要注意的是由于英法翻译数据 token 间不是使用空格进行分隔,需要修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`。 -训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用CPU训练(通过参数--divice CPU),训练速度相对较慢。在训练过程中,每个 epoch 结束后将保存模型到参数 `model_dir` 指定的目录,每个 iteration 将打印如下的日志到标准输出: +训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--divice CPU` 设置),训练速度相对较慢。在训练过程中,每个 epoch 结束后将保存模型到参数 `model_dir` 指定的目录,每个 epoch 内也会每隔1000个 iteration 进行一次保存,每个 iteration 将打印如下的日志到标准输出: ```txt epoch: 0, batch: 0, sum loss: 258793.343750, avg loss: 11.069005, ppl: 64151.644531 epoch: 0, batch: 1, sum loss: 256140.718750, avg loss: 11.059616, ppl: 63552.148438 @@ -126,37 +137,45 @@ epoch: 0, batch: 9, sum loss: 245157.500000, avg loss: 10.966562, ppl: 57905.187 ### 模型预测 -`infer.py` 是模型预测脚本,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译: +`infer.py` 是模型预测脚本。以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译: ```sh python -u infer.py \ --src_vocab_fpath data/vocab.bpe.32000 \ --trg_vocab_fpath data/vocab.bpe.32000 \ --special_token '' '' '' \ --test_file_pattern data/newstest2013.tok.bpe.32000.en-de \ + --use_wordpiece False \ + --token_delimiter ' ' \ --batch_size 4 \ model_path trained_models/pass_20.infer.model \ - beam_size 5 + beam_size 5 \ max_out_len 256 ``` 和模型训练时类似,预测时也需要设置数据和 reader 相关的参数,并可以执行 `python infer.py --help` 查看这些参数的说明(部分参数意义和训练时略有不同);同样可以在预测命令中设置模型超参数,但应与模型训练时的设置一致;此外相比于模型训练,预测时还有一些额外的参数,如需要设置 `model_path` 来给出模型所在目录,可以设置 `beam_size` 和 `max_out_len` 来指定 Beam Search 算法的搜索宽度和最大深度(翻译长度),这些参数也可以在 `config.py` 中的 `InferTaskConfig` 内查阅注释说明并进行更改设置。 
-执行以上预测命令会打印翻译结果到标准输出,每行输出是对应行输入的得分最高的翻译。需要注意,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要恢复成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估,可以使用以下命令来恢复 `predict.txt` 内的翻译结果到 `predict.tok.txt` 中。 - +执行以上预测命令会打印翻译结果到标准输出,每行输出是对应行输入的得分最高的翻译。对于使用 BPE 的英德数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估,可以使用以下命令来恢复 `predict.txt` 内的翻译结果到 `predict.tok.txt` 中(无需再次 tokenize 处理): ```sh sed 's/@@ //g' predict.txt > predict.tok.txt ``` -接下来就可以使用参考翻译(这里使用的是 `newstest2013.tok.de`)对翻译结果进行 BLEU 指标的评估了。计算 BLEU 值的一个较为广泛使用的脚本可以从[这里](https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl)获取,获取后执行如下命令: +对于英法翻译的 wordpiece 数据,执行预测和英德翻译预测类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外需要注意修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`;同时要修改 `use_wordpiece` 参数的设置为 `--use_wordpiece True`,这会在预测时将翻译得到的 wordpiece 数据还原为原始数据输出。为了使用 tokenize 的数据进行评估,还需要对翻译结果进行 tokenize 的处理,[Moses](https://github.com/moses-smt/mosesdecoder) 提供了一系列机器翻译相关的脚本。执行 `git clone https://github.com/moses-smt/mosesdecoder.git` 克隆 mosesdecoder 仓库后,可以使用其中的 `tokenizer.perl` 脚本对 `predict.txt` 内的翻译结果进行 tokenize 处理并输出到 `predict.tok.txt` 中,如下: +```sh +perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l fr < predict.txt > predict.tok.txt +``` + +接下来就可以使用参考翻译对翻译结果进行 BLEU 指标的评估了。计算 BLEU 值的脚本也在 Moses 中包含,以英德翻译 `newstest2013.tok.de` 数据为例,执行如下命令: ```sh -perl multi-bleu.perl data/newstest2013.tok.de < predict.tok.txt +perl mosesdecoder/scripts/generic/multi-bleu.perl data/newstest2013.tok.de < predict.tok.txt ``` 可以看到类似如下的结果。 ``` BLEU = 25.08, 58.3/31.5/19.6/12.6 (BP=0.966, ratio=0.967, hyp_len=61321, ref_len=63412) ``` +目前在未使用 model average 的情况下,使用默认配置单机八卡(同论文中 base model 的配置)进行训练,英德翻译在 `newstest2013` 上测试 BLEU 值为25.,在 `newstest2014` 上测试 BLEU 值为26.;英法翻译在 `newstest2014` 上测试 BLEU 值为36.。 + ### 分布式训练 -transformer 模型支持同步或者异步的分布式训练。分布式的配置主要两个方面: +Transformer 模型支持同步或者异步的分布式训练。分布式的配置主要两个方面: 1 命令行配置 @@ -234,3 +253,4 @@ export PADDLE_PORT=6177 2. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778. 3. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016. 4. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015. +5. Wu Y, Schuster M, Chen Z, et al. [Google's neural machine translation system: Bridging the gap between human and machine translation](https://arxiv.org/pdf/1609.08144.pdf)[J]. arXiv preprint arXiv:1609.08144, 2016. 
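The reader and training/inference changes below add a `--token_delimiter` option for exactly the delimiter difference the README describes. As a minimal sketch of that behavior (not part of the patch; the sample strings are invented, but the splitting mirrors what `reader.Converter` does with the delimiter it is given):

```python
# BPE (EN-DE) data uses spaces between tokens; wordpiece (EN-FR) data uses '\x01'.
bpe_line = "ein@@ kleines@@ Haus"               # invented EN-DE style sample
wordpiece_line = "ein_\x01kleines_\x01Haus_"    # invented EN-FR style sample

print(bpe_line.split(" "))        # ['ein@@', 'kleines@@', 'Haus']
print(wordpiece_line.split("\x01"))  # ['ein_', 'kleines_', 'Haus_']
```

This is why the EN-FR commands must pass `--token_delimiter '\x01'` while the EN-DE commands keep the default space delimiter.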
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index e3cc75d087e01c1c8a479c5632a402e067797cfc..ee1bd208c101a8cda21bc3e1fdbcb76c3b5a75b8 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -1,5 +1,7 @@ import argparse +import ast import numpy as np +from functools import partial import paddle import paddle.fluid as fluid @@ -11,6 +13,7 @@ from model import fast_decode as fast_decoder from config import * from train import pad_batch_data import reader +import util def parse_args(): @@ -46,6 +49,22 @@ def parse_args(): default=["", "", ""], nargs=3, help="The , and tokens in the dictionary.") + parser.add_argument( + "--use_wordpiece", + type=ast.literal_eval, + default=False, + help="The flag indicating if the data in wordpiece. The EN-FR data " + "we provided is wordpiece data. For wordpiece data, converting ids to " + "original words is a little different and some special codes are " + "provided in util.py to do this.") + parser.add_argument( + "--token_delimiter", + type=partial( + str.decode, encoding="string-escape"), + default=" ", + help="The delimiter used to split tokens in source or target sentences. " + "For EN-DE BPE data we provided, use spaces as token delimiter.; " + "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.") parser.add_argument( 'opts', help='See config.py for all options', @@ -320,7 +339,7 @@ def post_process_seq(seq, seq) -def py_infer(test_data, trg_idx2word): +def py_infer(test_data, trg_idx2word, use_wordpiece): """ Inference by beam search implented by python, while the calculations from symbols to probilities execute by Fluid operators. @@ -399,7 +418,10 @@ def py_infer(test_data, trg_idx2word): seqs = map(post_process_seq, batch_seqs[i]) scores = batch_scores[i] for seq in seqs: - print(" ".join([trg_idx2word[idx] for idx in seq])) + if use_wordpiece: + print(util.subword_ids_to_str(seq, trg_idx2word)) + else: + print(" ".join([trg_idx2word[idx] for idx in seq])) def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, @@ -465,7 +487,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, return input_dict -def fast_infer(test_data, trg_idx2word): +def fast_infer(test_data, trg_idx2word, use_wordpiece): """ Inference by beam search decoder based solely on Fluid operators. 
""" @@ -520,7 +542,9 @@ def fast_infer(test_data, trg_idx2word): trg_idx2word[idx] for idx in post_process_seq( np.array(seq_ids)[sub_start:sub_end]) - ])) + ]) if not use_wordpiece else util.subtoken_ids_to_str( + post_process_seq(np.array(seq_ids)[sub_start:sub_end]), + trg_idx2word)) scores[i].append(np.array(seq_scores)[sub_end - 1]) print hyps[i][-1] if len(hyps[i]) >= InferTaskConfig.n_best: @@ -533,8 +557,9 @@ def infer(args, inferencer=fast_infer): src_vocab_fpath=args.src_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath, fpattern=args.test_file_pattern, - batch_size=args.batch_size, + token_delimiter=args.token_delimiter, use_token_batch=False, + batch_size=args.batch_size, pool_size=args.pool_size, sort_type=reader.SortType.NONE, shuffle=False, @@ -547,7 +572,7 @@ def infer(args, inferencer=fast_infer): clip_last_batch=False) trg_idx2word = test_data.load_dict( dict_path=args.trg_vocab_fpath, reverse=True) - inferencer(test_data, trg_idx2word) + inferencer(test_data, trg_idx2word, args.use_wordpiece) if __name__ == "__main__": diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index 4815cd1a58bd7028291e28476f1f857396109f8b..3ff3969b1c35f5a9c25af7be9b6683a08bd8a19e 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -12,15 +12,17 @@ class SortType(object): class Converter(object): - def __init__(self, vocab, beg, end, unk): + def __init__(self, vocab, beg, end, unk, token_delimiter): self._vocab = vocab self._beg = beg self._end = end self._unk = unk + self._token_delimiter = token_delimiter def __call__(self, sentence): return [self._beg] + [ - self._vocab.get(w, self._unk) for w in sentence.split() + self._vocab.get(w, self._unk) + for w in sentence.split(self._token_delimiter) ] + [self._end] @@ -146,9 +148,12 @@ class DataReader(object): :param use_token_batch: Whether to produce batch data according to token number. :type use_token_batch: bool - :param delimiter: The delimiter used to split source and target in each - line of data file. - :type delimiter: basestring + :param field_delimiter: The delimiter used to split source and target in + each line of data file. + :type field_delimiter: basestring + :param token_delimiter: The delimiter used to split tokens in source or + target sentences. + :type token_delimiter: basestring :param start_mark: The token representing for the beginning of sentences in dictionary. 
:type start_mark: basestring @@ -175,7 +180,8 @@ class DataReader(object): shuffle=True, shuffle_batch=False, use_token_batch=False, - delimiter="\t", + field_delimiter="\t", + token_delimiter=" ", start_mark="", end_mark="", unk_mark="", @@ -194,7 +200,8 @@ class DataReader(object): self._shuffle_batch = shuffle_batch self._min_length = min_length self._max_length = max_length - self._delimiter = delimiter + self._field_delimiter = field_delimiter + self._token_delimiter = token_delimiter self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, unk_mark) self._random = random.Random(x=seed) @@ -206,7 +213,8 @@ class DataReader(object): vocab=self._src_vocab, beg=self._src_vocab[start_mark], end=self._src_vocab[end_mark], - unk=self._src_vocab[unk_mark]) + unk=self._src_vocab[unk_mark], + token_delimiter=self._token_delimiter) ] if not self._only_src: converters.append( @@ -214,7 +222,8 @@ class DataReader(object): vocab=self._trg_vocab, beg=self._trg_vocab[start_mark], end=self._trg_vocab[end_mark], - unk=self._trg_vocab[unk_mark])) + unk=self._trg_vocab[unk_mark], + token_delimiter=self._token_delimiter)) converters = ComposedConverter(converters) @@ -238,17 +247,23 @@ class DataReader(object): if tar_fname is None: raise Exception("If tar file provided, please set tar_fname.") - f = tarfile.open(fpaths[0], 'r') + f = tarfile.open(fpaths[0], "r") for line in f.extractfile(tar_fname): - yield line.split(self._delimiter) + fields = line.strip("\n").split(self._field_delimiter) + if (not self._only_src and len(fields) == 2) or ( + self._only_src and len(fields) == 1): + yield fields else: for fpath in fpaths: if not os.path.isfile(fpath): raise IOError("Invalid file: %s" % fpath) - with open(fpath, 'r') as f: + with open(fpath, "r") as f: for line in f: - yield line.split(self._delimiter) + fields = line.strip("\n").split(self._field_delimiter) + if (not self._only_src and len(fields) == 2) or ( + self._only_src and len(fields) == 1): + yield fields @staticmethod def load_dict(dict_path, reverse=False): @@ -256,9 +271,9 @@ class DataReader(object): with open(dict_path, "r") as fdict: for idx, line in enumerate(fdict): if reverse: - word_dict[idx] = line.strip() + word_dict[idx] = line.strip("\n") else: - word_dict[line.strip()] = idx + word_dict[line.strip("\n")] = idx return word_dict def batch_generator(self): diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 15f2fe861ffb0419232460648fcb8a82fc521c51..d429e94f89fdb4595f79995377a40f6160a82eb9 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -3,6 +3,7 @@ import ast import multiprocessing import os import time +from functools import partial import numpy as np import paddle.fluid as fluid @@ -75,6 +76,14 @@ def parse_args(): default=["", "", ""], nargs=3, help="The , and tokens in the dictionary.") + parser.add_argument( + "--token_delimiter", + type=partial( + str.decode, encoding="string-escape"), + default=" ", + help="The delimiter used to split tokens in source or target sentences. " + "For EN-DE BPE data we provided, use spaces as token delimiter. 
" + "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.") parser.add_argument( 'opts', help='See config.py for all options', @@ -272,6 +281,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, src_vocab_fpath=args.src_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath, fpattern=args.val_file_pattern, + token_delimiter=args.token_delimiter, use_token_batch=args.use_token_batch, batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), pool_size=args.pool_size, @@ -334,6 +344,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, src_vocab_fpath=args.src_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath, fpattern=args.train_file_pattern, + token_delimiter=args.token_delimiter, use_token_batch=args.use_token_batch, batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), pool_size=args.pool_size, @@ -376,6 +387,8 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, for batch_id, data in enumerate(train_data()): feed_list = [] total_num_token = 0 + if args.local: + lr_rate = lr_scheduler.update_learning_rate() for place_id, data_buffer in enumerate( split_data( data, num_part=dev_count)): @@ -387,7 +400,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, feed_kv_pairs = data_input_dict.items() + util_input_dict.items( ) if args.local: - lr_rate = lr_scheduler.update_learning_rate() feed_kv_pairs += { lr_scheduler.learning_rate.name: lr_rate }.items() @@ -411,6 +423,10 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" % (pass_id, batch_id, total_sum_cost, total_avg_cost, np.exp([min(total_avg_cost, 100)]))) + if batch_id > 0 and batch_id % 1000 == 0: + fluid.io.save_persistables( + exe, + os.path.join(TrainTaskConfig.ckpt_dir, "latest.checkpoint")) init = True # Validate and save the model for inference. print("epoch: %d, " % pass_id + diff --git a/fluid/neural_machine_translation/transformer/util.py b/fluid/neural_machine_translation/transformer/util.py new file mode 100644 index 0000000000000000000000000000000000000000..190abf92f4f48bfc943bd99bf61a222cc6c9d2f0 --- /dev/null +++ b/fluid/neural_machine_translation/transformer/util.py @@ -0,0 +1,68 @@ +import sys +import re +import six +import unicodedata + +# Regular expression for unescaping token strings. +# '\u' is converted to '_' +# '\\' is converted to '\' +# '\213;' is converted to unichr(213) +# Inverse of escaping. +_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") + +# This set contains all letter and number characters. +_ALPHANUMERIC_CHAR_SET = set( + six.unichr(i) for i in range(sys.maxunicode) + if (unicodedata.category(six.unichr(i)).startswith("L") or + unicodedata.category(six.unichr(i)).startswith("N"))) + + +def unescape_token(escaped_token): + """ + Inverse of encoding escaping. + """ + + def match(m): + if m.group(1) is None: + return u"_" if m.group(0) == u"\\u" else u"\\" + + try: + return six.unichr(int(m.group(1))) + except (ValueError, OverflowError) as _: + return u"\u3013" # Unicode for undefined character. + + trimmed = escaped_token[:-1] if escaped_token.endswith( + "_") else escaped_token + return _UNESCAPE_REGEX.sub(match, trimmed) + + +def subtoken_ids_to_str(subtoken_ids, vocabs): + """ + Convert a list of subtoken(word piece) ids to a native string. + Refer to SubwordTextEncoder in Tensor2Tensor. 
+ """ + subtokens = [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids] + + # Convert a list of subtokens to a list of tokens. + concatenated = "".join([ + t if isinstance(t, unicode) else t.decode("utf-8") for t in subtokens + ]) + split = concatenated.split("_") + tokens = [] + for t in split: + if t: + unescaped = unescape_token(t + "_") + if unescaped: + tokens.append(unescaped) + + # Convert a list of tokens to a unicode string (by inserting spaces bewteen + # word tokens). + token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] + ret = [] + for i, token in enumerate(tokens): + if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: + ret.append(u" ") + ret.append(token) + seq = "".join(ret) + + return seq.encode("utf-8") diff --git a/fluid/object_detection/.move.sh b/fluid/object_detection/.move.sh new file mode 100644 index 0000000000000000000000000000000000000000..d27f72663727d89e487a3eac9eec110818b28a97 --- /dev/null +++ b/fluid/object_detection/.move.sh @@ -0,0 +1 @@ +cp -r ./data/pascalvoc/. /home/.cache/paddle/dataset/pascalvoc diff --git a/fluid/object_detection/.run.sh b/fluid/object_detection/.run.sh new file mode 100644 index 0000000000000000000000000000000000000000..7be97aaf3dfa5344e872e8d12b0ff0d8f5405df3 --- /dev/null +++ b/fluid/object_detection/.run.sh @@ -0,0 +1,11 @@ +export MKL_NUM_THREADS=1 +export OMP_NUM_THREADS=1 +cudaid=${object_detection_cudaid:=0} # use 0-th card as default +export CUDA_VISIBLE_DEVICES=$cudaid + +if [ ! -d "/root/.cache/paddle/dataset/pascalvoc" ];then + mkdir -p /root/.cache/paddle/dataset/pascalvoc + ./data/pascalvoc/download.sh + bash ./.move.sh +fi +FLAGS_benchmark=true python train.py --batch_size=64 --num_passes=2 --for_model_ce=True --data_dir=/root/.cache/paddle/dataset/pascalvoc/ diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py index c29bd070eda4cf82f5ac36a3eb5699ae13ae86d2..24225ab31d063cd55df86d7a31d46434578b4b6a 100644 --- a/fluid/object_detection/train.py +++ b/fluid/object_detection/train.py @@ -32,6 +32,10 @@ add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78 add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94 add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample.") +add_arg('for_model_ce', bool, False, "Use CE to evaluate the model") +add_arg('data_dir', str, 'data/pascalvoc', "data directory") +add_arg('skip_batch_num', int, 5, "the num of minibatch to skip.") +add_arg('iterations', int, 120, "mini batchs.") #yapf: enable @@ -148,13 +152,20 @@ def train(args, print("Pass {0}, test map {1}".format(pass_id, test_map)) return best_map + train_num = 0 + total_train_time = 0.0 for pass_id in range(num_passes): start_time = time.time() prev_start_time = start_time - end_time = 0 + # end_time = 0 + every_pass_loss = [] + iter = 0 + pass_duration = 0.0 for batch_id, data in enumerate(train_reader()): prev_start_time = start_time start_time = time.time() + if args.for_model_ce and iter == args.iterations: + break if len(data) < (devices_num * 2): print("There are too few data to train on all devices.") continue @@ -165,11 +176,28 @@ def train(args, loss_v, = exe.run(fluid.default_main_program(), feed=feeder.feed(data), fetch_list=[loss]) - end_time = time.time() + # end_time = time.time() loss_v = np.mean(np.array(loss_v)) if batch_id % 20 == 0: print("Pass 
{0}, batch {1}, loss {2}, time {3}".format( pass_id, batch_id, loss_v, start_time - prev_start_time)) + + if args.for_model_ce and iter >= args.skip_batch_num or pass_id != 0: + batch_duration = time.time() - start_time + pass_duration += batch_duration + train_num += len(data) + every_pass_loss.append(loss_v) + iter += 1 + total_train_time += pass_duration + + if args.for_model_ce and pass_id == num_passes - 1: + examples_per_sec = train_num / total_train_time + cost = np.mean(every_pass_loss) + with open("train_speed_factor.txt", 'w') as f: + f.write('{:f}\n'.format(examples_per_sec)) + with open("train_cost_factor.txt", 'a+') as f: + f.write('{:f}\n'.format(cost)) + best_map = test(pass_id, best_map) if pass_id % 10 == 0 or pass_id == num_passes - 1: save_model(str(pass_id)) @@ -180,11 +208,11 @@ if __name__ == '__main__': args = parser.parse_args() print_arguments(args) - data_dir = 'data/pascalvoc' - train_file_list = 'trainval.txt' - val_file_list = 'test.txt' + data_dir = args.data_dir label_file = 'label_list' model_save_dir = args.model_save_dir + train_file_list = 'trainval.txt' + val_file_list = 'test.txt' if 'coco' in args.dataset: data_dir = 'data/coco' if '2014' in args.dataset:
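Taken together, the continuous-evaluation pieces added in this patch follow a simple contract: the training script prints `kpis <name> <value>` lines to stdout, `.run_ce.sh` pipes that output into `_ce.py`, and `_ce.py` records each tracked KPI. A minimal sketch of that contract (the sample log text below is invented; only the filtering logic mirrors `parse_log` in `fluid/mnist/_ce.py`):

```python
# Illustrative only: which stdout lines the CE parser keeps.
sample_log = "\n".join([
    "Pass = 0, Iter = 0, Loss = 0.35, Accuracy = 0.90",  # ignored: not a kpi line
    "kpis train_acc 0.97",
    "kpis train_cost 0.08",
    "kpis test_acc 0.98",
    "kpis train_duration 12.5",
])

for line in sample_log.split("\n"):
    fields = line.strip().split()
    if len(fields) == 3 and fields[0] == "kpis":
        # what _ce.py would feed into the matching Kpi tracker
        print("%s %f" % (fields[1], float(fields[2])))
```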