Unverified · Commit 7a5cca69 · Authored by: whs · Committed by: GitHub

Refine arguments of ocr models. (#3261)

* Refine arguments of ocr models.

* Add comments for issue of attention model in paddle1.5
Parent 25932361
> Note: Training the attention model on paddle 1.5 has a convergence problem. We recommend using paddle 1.4 for now; we will fix this issue in a later release.
## Code structure
```
├── data_reader.py      # Downloads, reads, and preprocesses data.
@@ -6,7 +8,9 @@
├── train.py            # Trains the model.
├── infer.py            # Loads a trained model and runs inference on new data.
├── eval.py             # Evaluates the model on a given dataset.
-└── utils.py            # Defines common utility functions.
+├── utils.py            # Defines common utility functions.
+├── run_crnn_ctc.sh     # Launches the crnn_ctc model training job.
+└── run_attention.sh    # Launches the attention model training job.
```
@@ -136,14 +140,14 @@ env CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --parallel=True
Run `python train.py --help` to see more usage information and detailed argument descriptions.
-Figure 2 shows the convergence curves of the `CTC model` trained on the default dataset with the default arguments. The horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line is that on the test set. The lowest error rate on the test set is 22.0%.
+Figure 2 shows the convergence curves of the `CTC model` trained on the default dataset by running the script `run_crnn_ctc.sh`. The horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line is that on the test set. The lowest error rate on the test set is 22.0%.
<p align="center">
<img src="images/train.jpg" width="400" hspace='10'/> <br/>
<strong>Figure 2</strong>
</p>
-Figure 3 shows the convergence curves of the `attention model` trained on the default dataset with the default arguments. The horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line is that on the test set. The lowest error rate on the test set is 16.25%.
+Figure 3 shows the convergence curves of the `attention model` trained on the default dataset by running the script `run_attention.sh`. The horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line is that on the test set. The lowest error rate on the test set is 16.25%.
<p align="center">
<img src="images/train_attention.jpg" width="400" hspace='10'/> <br/>
@@ -22,11 +22,7 @@ word_vector_dim = 128
max_length = 100
sos = 0
eos = 1
-gradient_clip = 10
-LR = 1.0
beam_size = 1
-learning_rate_decay = None
def conv_bn_pool(input,
                 group,
@@ -192,7 +188,7 @@ def attention_train_net(args, data_shape, num_classes):
    prediction = gru_decoder_with_attention(trg_embedding, encoded_vector,
                                            encoded_proj, decoder_boot,
                                            decoder_size, num_classes)
-    fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip))
+    fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(args.gradient_clip))
    label_out = fluid.layers.cast(x=label_out, dtype='int64')
    _, maxid = fluid.layers.topk(input=prediction, k=1)
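For reference, `GradientClipByValue(args.gradient_clip)` saturates each gradient element into `[-threshold, threshold]`. A minimal pure-Python sketch of that behavior (the `clip_by_value` helper below is illustrative, not part of this repo):

```
# Illustrative sketch: with the new default --gradient_clip=10.0,
# every gradient element outside [-10, 10] is saturated to the bound.
def clip_by_value(grad, threshold=10.0):
    return [max(-threshold, min(threshold, g)) for g in grad]

print(clip_by_value([-25.0, 3.0, 12.5]))  # -> [-10.0, 3.0, 10.0]
```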
@@ -203,8 +199,8 @@ def attention_train_net(args, data_shape, num_classes):
    cost = fluid.layers.cross_entropy(input=prediction, label=label_out)
    sum_cost = fluid.layers.reduce_sum(cost)
-    if learning_rate_decay == "piecewise_decay":
+    LR = args.lr
+    if args.lr_decay_strategy == "piecewise_decay":
        learning_rate = fluid.layers.piecewise_decay([50000], [LR, LR * 0.01])
    else:
        learning_rate = LR
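To make the refactored schedule concrete, here is a minimal pure-Python sketch of the value `fluid.layers.piecewise_decay([50000], [LR, LR * 0.01])` evaluates to at a given step (`piecewise_lr` is an illustrative stand-in, not the Fluid implementation):

```
def piecewise_lr(step, boundaries, values):
    # Keep values[i] while the step is below boundaries[i]; after the
    # last boundary, use the final value.
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]

base_lr = 1.0  # matches --lr=1.0 in run_attention.sh
assert piecewise_lr(0, [50000], [base_lr, base_lr * 0.01]) == 1.0
assert piecewise_lr(60000, [50000], [base_lr, base_lr * 0.01]) == 0.01
```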
@@ -184,9 +184,9 @@ def encoder_net(images,
def ctc_train_net(args, data_shape, num_classes):
-    L2_RATE = 0.0004
-    LR = 1.0e-3
-    MOMENTUM = 0.9
+    L2_RATE = args.l2decay
+    LR = args.lr
+    MOMENTUM = args.momentum
    learning_rate_decay = None
    regularizer = fluid.regularizer.L2Decay(L2_RATE)
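Since the hyperparameters now come from the parser, `ctc_train_net` can feed them straight into the optimizer. A hedged sketch of that wiring, assuming the Paddle 1.x Fluid API (`build_ctc_optimizer` is a hypothetical helper, not a function in this diff):

```
import paddle.fluid as fluid

def build_ctc_optimizer(args):
    # Hyperparameters come from parsed args instead of the old
    # hard-coded module constants.
    regularizer = fluid.regularizer.L2Decay(args.l2decay)  # was 0.0004
    return fluid.optimizer.Momentum(
        learning_rate=args.lr,       # was 1.0e-3
        momentum=args.momentum,      # was 0.9
        regularization=regularizer)
```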
run_attention.sh:
export CUDA_VISIBLE_DEVICES=0
nohup python train.py \
--lr=1.0 \
--gradient_clip=10 \
--model="attention" \
--log_period=10 \
> attention.log 2>&1 &
tailf attention.log
run_crnn_ctc.sh:
export CUDA_VISIBLE_DEVICES=0
nohup python train.py \
--lr=1e-3 \
--l2decay=4e-4 \
--momentum=0.9 \
--model="crnn_ctc" \
--log_period=10 \
> crnn_ctc.log 2>&1 &
tailf crnn_ctc.log
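Both scripts start training in the background with `nohup` and follow the log with `tailf`; on systems without `tailf`, `tail -f attention.log` (or `tail -f crnn_ctc.log`) does the same.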
@@ -32,26 +32,33 @@ import numpy as np
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
-add_arg('batch_size',        int,   32,        "Minibatch size.")
-add_arg('total_step',        int,   720000,    "The number of iterations. Zero or less means the whole training set. More than 0 means the training set might be looped until the number of iterations is reached.")
-add_arg('log_period',        int,   1000,      "Log period.")
-add_arg('save_model_period', int,   15000,     "Save model period. '-1' means never saving the model.")
-add_arg('eval_period',       int,   15000,     "Evaluate period. '-1' means never evaluating the model.")
-add_arg('save_model_dir',    str,   "./models", "The directory the model is saved to.")
-add_arg('train_images',      str,   None,      "The directory of images to be used for training.")
-add_arg('train_list',        str,   None,      "The list file of images to be used for training.")
-add_arg('test_images',       str,   None,      "The directory of images to be used for testing.")
-add_arg('test_list',         str,   None,      "The list file of images to be used for testing.")
-add_arg('model',             str,   "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'.")
-add_arg('init_model',        str,   None,      "The init model file or directory.")
-add_arg('use_gpu',           bool,  True,      "Whether to use GPU for training.")
-add_arg('min_average_window', int,  10000,     "Min average window.")
-add_arg('max_average_window', int,  12500,     "Max average window. It is proposed to be set as the number of minibatches in a pass.")
-add_arg('average_window',    float, 0.15,      "Average window.")
-add_arg('parallel',          bool,  False,     "Whether to use parallel training.")
-add_arg('profile',           bool,  False,     "Whether to use profiling.")
-add_arg('skip_batch_num',    int,   0,         "The number of first minibatches to skip as warm-up for a better performance test.")
-add_arg('skip_test',         bool,  False,     "Whether to skip the test phase.")
+add_arg('batch_size',        int,   32,        "Minibatch size.")
+add_arg('lr',                float, 1e-3,      "Learning rate.")
+add_arg('lr_decay_strategy', str,   None,      "Learning rate decay strategy. 'piecewise_decay' or None is valid.")
+add_arg('l2decay',           float, 4e-4,      "L2 decay rate.")
+add_arg('momentum',          float, 0.9,       "Momentum rate.")
+add_arg('gradient_clip',     float, 10.0,      "The threshold of gradient clipping.")
+add_arg('total_step',        int,   720000,    "The number of iterations. Zero or less means the whole training set. More than 0 means the training set might be looped until the number of iterations is reached.")
+add_arg('log_period',        int,   1000,      "Log period.")
+add_arg('save_model_period', int,   15000,     "Save model period. '-1' means never saving the model.")
+add_arg('eval_period',       int,   15000,     "Evaluate period. '-1' means never evaluating the model.")
+add_arg('save_model_dir',    str,   "./models", "The directory the model is saved to.")
+add_arg('train_images',      str,   None,      "The directory of images to be used for training.")
+add_arg('train_list',        str,   None,      "The list file of images to be used for training.")
+add_arg('test_images',       str,   None,      "The directory of images to be used for testing.")
+add_arg('test_list',         str,   None,      "The list file of images to be used for testing.")
+add_arg('model',             str,   "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'.")
+add_arg('init_model',        str,   None,      "The init model file or directory.")
+add_arg('use_gpu',           bool,  True,      "Whether to use GPU for training.")
+add_arg('min_average_window', int,  10000,     "Min average window.")
+add_arg('max_average_window', int,  12500,     "Max average window. It is proposed to be set as the number of minibatches in a pass.")
+add_arg('average_window',    float, 0.15,      "Average window.")
+add_arg('parallel',          bool,  False,     "Whether to use parallel training.")
+add_arg('profile',           bool,  False,     "Whether to use profiling.")
+add_arg('skip_batch_num',    int,   0,         "The number of first minibatches to skip as warm-up for a better performance test.")
+add_arg('skip_test',         bool,  False,     "Whether to skip the test phase.")
# yapf: enable
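For context, `add_arg` is a `functools.partial` over an `add_arguments` helper defined in the repo's utility module, which this diff does not show. A minimal sketch of what such a helper conventionally looks like in PaddlePaddle model repos (a reconstruction under that assumption, not the file's verbatim contents):

```
import argparse
import functools
from distutils.util import strtobool

def add_arguments(argname, type, default, help, argparser, **kwargs):
    # Register --<argname> with a default and help text; bool flags are
    # parsed leniently so --use_gpu=False and --use_gpu=0 both work.
    type = strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname, default=default, type=type,
        help=help + " Default: %(default)s.", **kwargs)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('lr', float, 1e-3, "Learning rate.")
```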
@@ -131,12 +138,13 @@ def train(args):
for data in test_reader():
    exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
-print("\nTime: %s; Iter[%d]; Test seq error: %s.\n" %
-      (time.time(), iter_num, str(test_seq_error[0])))
+print("\n[%s] - Iter[%d]; Test seq error: %s.\n" %
+      (time.asctime(time.localtime(time.time())), iter_num, str(test_seq_error[0])))
# Note: The following logs are special for CE monitoring.
# Other situations do not need to care about these logs.
-print("kpis test_acc %f" % (1 - test_seq_error[0]))
+if 'ce_mode' in os.environ:
+    print("kpis test_acc %f" % (1 - test_seq_error[0]))
def save_model(args, exe, iter_num):
    filename = "model_%05d" % iter_num
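The guard introduced above keeps KPI lines out of ordinary runs: they are printed only when the `ce_mode` environment variable is set, e.g. when launched as `ce_mode=1 python train.py ...`. A tiny sketch of the pattern with a placeholder value:

```
import os

if 'ce_mode' in os.environ:            # set only by the CE runner
    print("kpis test_acc %f" % 0.78)   # 0.78 is a placeholder value
```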
@@ -171,14 +179,15 @@ def train(args):
iter_num += 1
# training log
if iter_num % args.log_period == 0:
-    print("\nTime: %s; Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f"
-          % (time.time(), iter_num,
+    print("\n[%s] - Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f"
+          % (time.asctime(time.localtime(time.time())), iter_num,
             total_loss / (args.log_period * args.batch_size),
             total_seq_error / (args.log_period * args.batch_size)))
-    print("kpis train_cost %f" % (total_loss / (args.log_period *
+    if 'ce_mode' in os.environ:
+        print("kpis train_cost %f" % (total_loss / (args.log_period *
                                                  args.batch_size)))
-    print("kpis train_acc %f" % (
-        1 - total_seq_error / (args.log_period * args.batch_size)))
+        print("kpis train_acc %f" % (
+            1 - total_seq_error / (args.log_period * args.batch_size)))
    total_loss = 0.0
    total_seq_error = 0.0
@@ -198,7 +207,8 @@ def train(args):
else:
    save_model(args, exe, iter_num)
end_time = time.time()
-print("kpis train_duration %f" % (end_time - start_time))
+if 'ce_mode' in os.environ:
+    print("kpis train_duration %f" % (end_time - start_time))
# Postprocess benchmark data
latencies = batch_times[args.skip_batch_num:]
latency_avg = np.average(latencies)
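The benchmark postprocessing discards the warm-up batches (`args.skip_batch_num`) before computing latency statistics. A short self-contained sketch of the same computation with hypothetical timings:

```
import numpy as np

batch_times = [0.90, 0.31, 0.30, 0.33, 0.29]  # hypothetical seconds per batch
skip_batch_num = 1                            # mirrors --skip_batch_num
latencies = batch_times[skip_batch_num:]      # drop warm-up batches
print("avg latency: %.3f s" % np.average(latencies))
print("p99 latency: %.3f s" % np.percentile(latencies, 99))
```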