PaddlePaddle / hapi — commit ae47e2a8

Refine seq2seq

Authored Apr 14, 2020 by guosheng. Parent commit: 8aca373d.

Showing 7 changed files with 300 additions and 527 deletions.
- seq2seq/README.md            +180  −0
- seq2seq/reader.py            +110  −45
- seq2seq/run.sh                 +2  −0
- seq2seq/seq2seq_add_attn.py    +0  −293
- seq2seq/train.py               +8  −48
- seq2seq/train_ocr.py           +0  −140
- transformer/reader.py          +0  −1

seq2seq/README.md (new file, 0 → 100644)

Running the example model in this directory requires PaddlePaddle Fluid 1.7. If your installed PaddlePaddle is older than this, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).

# Sequence to Sequence (Seq2Seq)

The directory structure of this example is as follows:
```
.
├── README.md            # This file (documentation)
├── args.py              # Configuration of training, inference and model parameters
├── reader.py            # Data reading
├── download.py          # Data download
├── train.py             # Main training program
├── infer.py             # Main inference program
├── run.sh               # Launch script with the default configuration
├── infer.sh             # Decoding script with the default configuration
├── attention_model.py   # Translation model with attention
└── base_model.py        # Translation model without attention
```
## Introduction

Sequence to Sequence (Seq2Seq) uses an encoder-decoder architecture: an encoder encodes the source sequence into a vector, and a decoder decodes that vector into the target sequence. Seq2Seq is widely used in machine translation, dialogue systems, automatic document summarization, image captioning, and similar tasks.

This directory contains a classic Seq2Seq example, machine translation, implemented as a base model (without attention) and a translation model with attention. A Seq2Seq translation model mimics how a person approaches translation: first parse the source sentence and understand its meaning, then write the target-language sentence that expresses that meaning. For more on the principles and the mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.

## Model overview

In this example, the encoder is a multi-layer LSTM-based RNN encoder; the decoder is an RNN decoder with an attention mechanism, and a decoder without attention is also provided for comparison. At inference time, beam search is used to generate the translated target sentence.
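For intuition only, the sketch below shows how an attention step combines the encoder outputs with the current decoder state. It is plain NumPy, not code from this repository; the function name, shapes, and the simple dot-product scoring are illustrative assumptions (the model here uses a learned attention defined in attention_model.py).

```python
import numpy as np

def attention_step(encoder_outputs, decoder_state):
    """Illustrative dot-product attention (not the repository's implementation).

    encoder_outputs: (src_len, hidden) -- encoder hidden states
    decoder_state:   (hidden,)         -- current decoder hidden state
    """
    # Alignment scores between the decoder state and every source position.
    scores = encoder_outputs @ decoder_state          # (src_len,)
    # Normalize the scores into attention weights (softmax).
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Context vector: weighted sum of encoder states, fed to the decoder.
    context = weights @ encoder_outputs               # (hidden,)
    return context, weights

# Toy usage with random states: 5 source tokens, hidden size 8.
context, weights = attention_step(np.random.randn(5, 8), np.random.randn(8))
print(weights.round(3), context.shape)
```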
## Data

This example uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, the tst2012 data as the development set, and the tst2013 data as the test set.

### Downloading the data
```
python download.py
```
## Training

`run.sh` wraps the main training entry point. To start training with the default parameters, simply run:
```
sh run.sh
```
The RNN model with attention is used by default; to train the RNN model without attention, set the `attention` argument to False.
```sh
export CUDA_VISIBLE_DEVICES=0

python train.py \
    --src_lang en --tar_lang vi \
    --attention True \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --train_data_prefix data/en-vi/train \
    --eval_data_prefix data/en-vi/tst2012 \
    --test_data_prefix data/en-vi/tst2013 \
    --vocab_prefix data/en-vi/vocab \
    --use_gpu True \
    --model_path ./attention_models
```
The training program saves the model once at the end of every epoch. Training runs in dynamic graph (eager) mode by default; set the `eager_run` argument to False to train in static graph mode, as follows:
```sh
export CUDA_VISIBLE_DEVICES=0

python train.py \
    --src_lang en --tar_lang vi \
    --attention True \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --train_data_prefix data/en-vi/train \
    --eval_data_prefix data/en-vi/tst2012 \
    --test_data_prefix data/en-vi/tst2013 \
    --vocab_prefix data/en-vi/vocab \
    --use_gpu True \
    --model_path ./attention_models \
    --eager_run False
```
## Inference

Once training has finished, you can run inference with the infer.sh script. By default it decodes the test set with beam search, loading the model saved after the 10th epoch:
```
sh infer.sh
```
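For intuition, beam search keeps the `beam_size` best partial translations at every decoding step. The snippet below is a minimal, self-contained sketch of one expansion step in plain NumPy; it is not the decoder used by infer.py, and all names and shapes are illustrative assumptions.

```python
import numpy as np

def beam_search_step(beam_scores, next_token_logprobs, beam_size):
    """Expand every beam by every token and keep the best `beam_size` pairs.

    beam_scores:         (beam,)       accumulated log-probabilities
    next_token_logprobs: (beam, vocab) log-probabilities of the next token
    """
    total = beam_scores[:, None] + next_token_logprobs   # (beam, vocab)
    flat = total.reshape(-1)
    top = np.argsort(flat)[::-1][:beam_size]              # best (beam, token) pairs
    vocab = next_token_logprobs.shape[1]
    return top // vocab, top % vocab, flat[top]            # parent beam, token, score

# Toy usage: 2 beams over a vocabulary of 5 tokens.
parents, tokens, scores = beam_search_step(
    np.log([0.6, 0.4]), np.log(np.full((2, 5), 0.2)), beam_size=2)
print(parents, tokens, scores)
```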
To run inference on a different data file, just change the --infer_file argument.
```sh
export CUDA_VISIBLE_DEVICES=0

python infer.py \
    --attention True \
    --src_lang en --tar_lang vi \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --vocab_prefix data/en-vi/vocab \
    --infer_file data/en-vi/tst2013.en \
    --reload_model attention_models/epoch_10 \
    --infer_output_file attention_infer_output/infer_output.txt \
    --beam_size 10 \
    --use_gpu True
```
As with training, inference can also be run in static graph mode:
```sh
export CUDA_VISIBLE_DEVICES=0

python infer.py \
    --attention True \
    --src_lang en --tar_lang vi \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --vocab_prefix data/en-vi/vocab \
    --infer_file data/en-vi/tst2013.en \
    --reload_model attention_models/epoch_10 \
    --infer_output_file attention_infer_output/infer_output.txt \
    --beam_size 10 \
    --use_gpu True --eager_run False
```
## Evaluation

Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the model's output, as follows:
```sh
mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
```
Each model was trained 10 times; each run used the model saved at the 10th epoch for prediction, with beam_size=10. The results are shown below (for easier reading, the 10 results of each setting are sorted in ascending order):
```
> no attention
tst2012 BLEU:
[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4]
tst2013 BLEU:
[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44]
> with attention
tst2012 BLEU:
[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4]
tst2013 BLEU:
[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63]
```
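As a quick summary of the numbers listed above (and nothing more), the per-setting averages on tst2013 can be computed directly from those lists; the snippet is illustrative and not part of the repository.

```python
import numpy as np

# BLEU scores on tst2013, copied from the table above (10 runs each).
with_attention = [23.41, 24.79, 25.11, 25.12, 25.19, 25.24, 25.39, 25.61, 25.61, 25.63]
no_attention = [10.71, 10.71, 10.74, 10.76, 10.91, 10.94, 11.02, 11.16, 11.21, 11.44]

print("with attention:", round(float(np.mean(with_attention)), 2))  # 25.11
print("no attention:  ", round(float(np.mean(no_attention)), 2))    # 10.96
```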
seq2seq/reader.py

@@ -17,13 +17,58 @@ from __future__ import division
 from __future__ import print_function

 import glob
-import six
-import os
 import io
-import numpy as np
 import itertools
 from functools import partial

+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import BatchSampler, DataLoader, Dataset


+def create_data_loader(args, device, for_train=True):
+    data_loaders = [None, None]
+    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
+                     ] if args.eval_data_prefix else [args.train_data_prefix]
+    for i, data_prefix in enumerate(data_prefixes):
+        dataset = Seq2SeqDataset(
+            fpattern=data_prefix + "." + args.src_lang,
+            trg_fpattern=data_prefix + "." + args.tar_lang,
+            src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
+            trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
+            token_delimiter=None,
+            start_mark="<s>",
+            end_mark="</s>",
+            unk_mark="<unk>",
+            max_length=args.max_len if i == 0 else None,
+            truncate=True)
+        (args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
+         unk_id) = dataset.get_vocab_summary()
+        batch_sampler = Seq2SeqBatchSampler(
+            dataset=dataset,
+            use_token_batch=False,
+            batch_size=args.batch_size,
+            pool_size=args.batch_size * 20,
+            sort_type=SortType.POOL,
+            shuffle=False if args.enable_ce else True)
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            places=device,
+            collate_fn=partial(
+                prepare_train_input, bos_id=bos_id, eos_id=eos_id,
+                pad_id=eos_id),
+            num_workers=0,
+            return_list=True)
+        data_loaders[i] = data_loader
+    return data_loaders
+
+
 def prepare_train_input(insts, bos_id, eos_id, pad_id):
     src, src_length = pad_batch_data(
         [inst[0] for inst in insts], pad_id=pad_id)

@@ -118,10 +163,11 @@ class TokenBatchCreator(object):

 class SampleInfo(object):
-    def __init__(self, i, max_len, min_len):
+    def __init__(self, i, lens):
         self.i = i
-        self.min_len = min_len
-        self.max_len = max_len
+        # to be consistent with origianl reader implementation
+        self.min_len = lens[0]
+        self.max_len = lens[0]


 class MinMaxFilter(object):

@@ -131,9 +177,8 @@ class MinMaxFilter(object):
         self._creator = underlying_creator

     def append(self, info):
-        if info.max_len > self._max_len or info.min_len < self._min_len:
-            return
-        else:
-            self._creator.append(info)
+        if (self._min_len is None or info.min_len >= self._min_len) and (
+                self._max_len is None or info.max_len <= self._max_len):
+            return self._creator.append(info)

     @property

@@ -151,22 +196,30 @@ class Seq2SeqDataset(Dataset):
                  start_mark="<s>",
                  end_mark="<e>",
                  unk_mark="<unk>",
-                 only_src=False,
-                 trg_fpattern=None):
-        # convert str to bytes, and use byte data
-        # field_delimiter = field_delimiter.encode("utf8")
-        # token_delimiter = token_delimiter.encode("utf8")
-        # start_mark = start_mark.encode("utf8")
-        # end_mark = end_mark.encode("utf8")
-        # unk_mark = unk_mark.encode("utf8")
-        self._src_vocab = self.load_dict(src_vocab_fpath)
-        self._trg_vocab = self.load_dict(trg_vocab_fpath)
+                 trg_fpattern=None,
+                 byte_data=False,
+                 min_length=None,
+                 max_length=None,
+                 truncate=False):
+        if byte_data:
+            # The WMT16 bpe data used here seems including bytes can not be
+            # decoded by utf8. Thus convert str to bytes, and use byte data
+            field_delimiter = field_delimiter.encode("utf8")
+            token_delimiter = token_delimiter.encode("utf8")
+            start_mark = start_mark.encode("utf8")
+            end_mark = end_mark.encode("utf8")
+            unk_mark = unk_mark.encode("utf8")
+        self._byte_data = byte_data
+        self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
         self._bos_idx = self._src_vocab[start_mark]
         self._eos_idx = self._src_vocab[end_mark]
         self._unk_idx = self._src_vocab[unk_mark]
-        self._only_src = only_src
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
+        self._min_length = min_length
+        self._max_length = max_length
+        self._truncate = truncate
         self.load_src_trg_ids(fpattern, trg_fpattern)

     def load_src_trg_ids(self, fpattern, trg_fpattern=None):

@@ -195,26 +248,32 @@ class Seq2SeqDataset(Dataset):
         self._sample_infos = []

         slots = [self._src_seq_ids, self._trg_seq_ids]
-        lens = []
         for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
-            lens = []
-            for field, slot in zip(converters(line), slots):
-                slot.append(field)
-                lens.append(len(field))
-            # self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
-            self._sample_infos.append(SampleInfo(i, lens[0], lens[0]))
+            fields = converters(line)
+            lens = [len(field) for field in fields]
+            sample = SampleInfo(i, lens)
+            if (self._min_length is None or
+                    sample.min_len >= self._min_length) and (
+                        self._max_length is None or
+                        sample.max_len <= self._max_length or self._truncate):
+                for field, slot in zip(fields, slots):
+                    slot.append(field[:self._max_length] if self._truncate and
+                                self._max_length is not None else field)
+                self._sample_infos.append(sample)

     def _load_lines(self, fpattern, trg_fpattern=None):
         fpaths = glob.glob(fpattern)
         fpaths = sorted(fpaths)  # TODO: Add custum sort
         assert len(fpaths) > 0, "no matching file to the provided data path"

+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8", "\n")
         if trg_fpattern is None:
             for fpath in fpaths:
-                # with io.open(fpath, "rb") as f:
-                with io.open(fpath, "r", encoding="utf8") as f:
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                     for line in f:
-                        fields = line.strip("\n").split(self._field_delimiter)
+                        fields = line.strip(endl).split(self._field_delimiter)
                         yield fields
         else:
             # separated source and target language data files

@@ -228,24 +287,24 @@ class Seq2SeqDataset(Dataset):
                 with that of source language"
             for fpath, trg_fpath in zip(fpaths, trg_fpaths):
-                # with io.open(fpath, "rb") as f:
-                #     with io.open(trg_fpath, "rb") as trg_f:
-                with io.open(fpath, "r", encoding="utf8") as f:
-                    with io.open(trg_fpath, "r", encoding="utf8") as trg_f:
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
+                    with io.open(
+                            trg_fpath, f_mode, encoding=f_encoding) as trg_f:
                         for line in zip(f, trg_f):
-                            fields = [field.strip("\n") for field in line]
+                            fields = [field.strip(endl) for field in line]
                             yield fields

     @staticmethod
-    def load_dict(dict_path, reverse=False):
+    def load_dict(dict_path, reverse=False, byte_data=False):
         word_dict = {}
-        # with io.open(dict_path, "rb") as fdict:
-        with io.open(dict_path, "r", encoding="utf8") as fdict:
+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip("\n")
+                    word_dict[idx] = line.strip(endl)
                 else:
-                    word_dict[line.strip("\n")] = idx
+                    word_dict[line.strip(endl)] = idx
         return word_dict

     def get_vocab_summary(self):

@@ -266,19 +325,21 @@ class Seq2SeqBatchSampler(BatchSampler):
                  batch_size,
                  pool_size=10000,
                  sort_type=SortType.NONE,
-                 min_length=0,
-                 max_length=100,
+                 min_length=None,
+                 max_length=None,
                  shuffle=False,
                  shuffle_batch=False,
                  use_token_batch=False,
                  clip_last_batch=False,
-                 seed=None):
+                 distribute_mode=True,
+                 seed=0):
         for arg, value in locals().items():
             if arg != "self":
                 setattr(self, "_" + arg, value)
         self._random = np.random
         self._random.seed(seed)
         # for multi-devices
+        self._distribute_mode = distribute_mode
         self._nranks = ParallelEnv().nranks
         self._local_rank = ParallelEnv().local_rank
         self._device_id = ParallelEnv().dev_id

@@ -337,11 +398,14 @@ class Seq2SeqBatchSampler(BatchSampler):
         # for multi-device
         for batch_id, batch in enumerate(batches):
-            if batch_id % self._nranks == self._local_rank:
+            if not self._distribute_mode or (
+                    batch_id % self._nranks == self._local_rank):
                 batch_indices = [info.i for info in batch]
                 yield batch_indices
-        if self._local_rank > len(batches) % self._nranks:
-            yield batch_indices
+        if self._distribute_mode and len(batches) % self._nranks != 0:
+            if self._local_rank >= len(batches) % self._nranks:
+                # use previous data to pad
+                yield batch_indices

     def __len__(self):
         if not self._use_token_batch:

@@ -349,5 +413,6 @@ class Seq2SeqBatchSampler(BatchSampler):
                 len(self._dataset) + self._batch_size * self._nranks - 1) // (
                     self._batch_size * self._nranks)
         else:
-            batch_number = 100
+            # TODO(guosheng): fix the uncertain length
+            batch_number = 1
         return batch_number
seq2seq/run.sh

+export CUDA_VISIBLE_DEVICES=0
 python train.py \
     --src_lang en --tar_lang vi \
     --attention True \
 ...
seq2seq/seq2seq_add_attn.py (deleted, 100644 → 0)

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit

from text import DynamicDecode, RNN, RNNCell
from model import Model, Loss


class ConvBNPool(fluid.dygraph.Layer):
    def __init__(self,
                 out_ch,
                 channels,
                 act="relu",
                 is_test=False,
                 pool=True,
                 use_cudnn=True):
        super(ConvBNPool, self).__init__()
        self.pool = pool

        filter_size = 3
        conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
        conv_param_0 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, conv_std_0))

        conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
        conv_param_1 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, conv_std_1))

        self.conv_0_layer = Conv2D(
            channels[0],
            out_ch[0],
            3,
            padding=1,
            param_attr=conv_param_0,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn_0_layer = BatchNorm(out_ch[0], act=act, is_test=is_test)
        self.conv_1_layer = Conv2D(
            out_ch[0],
            num_filters=out_ch[1],
            filter_size=3,
            padding=1,
            param_attr=conv_param_1,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn_1_layer = BatchNorm(out_ch[1], act=act, is_test=is_test)

        if self.pool:
            self.pool_layer = Pool2D(
                pool_size=2,
                pool_type='max',
                pool_stride=2,
                use_cudnn=use_cudnn,
                ceil_mode=True)

    def forward(self, inputs):
        conv_0 = self.conv_0_layer(inputs)
        bn_0 = self.bn_0_layer(conv_0)
        conv_1 = self.conv_1_layer(bn_0)
        bn_1 = self.bn_1_layer(conv_1)
        if self.pool:
            bn_pool = self.pool_layer(bn_1)
            return bn_pool
        return bn_1


class OCRConv(fluid.dygraph.Layer):
    def __init__(self, is_test=False, use_cudnn=True):
        super(OCRConv, self).__init__()
        self.conv_bn_pool_1 = ConvBNPool(
            [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_2 = ConvBNPool(
            [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_3 = ConvBNPool(
            [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_4 = ConvBNPool(
            [128, 128], [64, 128],
            is_test=is_test,
            pool=False,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        inputs_1 = self.conv_bn_pool_1(inputs)
        inputs_2 = self.conv_bn_pool_2(inputs_1)
        inputs_3 = self.conv_bn_pool_3(inputs_2)
        inputs_4 = self.conv_bn_pool_4(inputs_3)
        return inputs_4


class SimpleAttention(fluid.dygraph.Layer):
    def __init__(self, decoder_size):
        super(SimpleAttention, self).__init__()
        self.fc1 = Linear(decoder_size, decoder_size, bias_attr=False)
        self.fc2 = Linear(decoder_size, 1, bias_attr=False)

    def forward(self, encoder_vec, encoder_proj, decoder_state):
        decoder_state = self.fc1(decoder_state)
        decoder_state = fluid.layers.unsqueeze(decoder_state, [1])

        mix = fluid.layers.elementwise_add(encoder_proj, decoder_state)
        mix = fluid.layers.tanh(x=mix)
        attn_score = self.fc2(mix)
        attn_scores = layers.squeeze(attn_score, [2])
        attn_scores = fluid.layers.softmax(attn_scores)

        scaled = fluid.layers.elementwise_mul(
            x=encoder_vec, y=attn_scores, axis=0)
        context = fluid.layers.reduce_sum(scaled, dim=1)
        return context


class GRUCell(RNNCell):
    def __init__(self,
                 input_size,
                 hidden_size,
                 param_attr=None,
                 bias_attr=None,
                 gate_activation='sigmoid',
                 candidate_activation='tanh',
                 origin_mode=False):
        super(GRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.fc_layer = Linear(
            input_size,
            hidden_size * 3,
            param_attr=param_attr,
            bias_attr=False)
        self.gru_unit = GRUUnit(
            hidden_size * 3,
            param_attr=param_attr,
            bias_attr=bias_attr,
            activation=candidate_activation,
            gate_activation=gate_activation,
            origin_mode=origin_mode)

    def forward(self, inputs, states):
        # step_outputs, new_states = cell(step_inputs, states)
        # for GRUCell, `step_outputs` and `new_states` both are hidden
        x = self.fc_layer(inputs)
        hidden, _, _ = self.gru_unit(x, states)
        return hidden, hidden

    @property
    def state_shape(self):
        return [self.hidden_size]


class EncoderNet(fluid.dygraph.Layer):
    def __init__(self,
                 decoder_size,
                 rnn_hidden_size=200,
                 is_test=False,
                 use_cudnn=True):
        super(EncoderNet, self).__init__()
        self.rnn_hidden_size = rnn_hidden_size
        para_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02))
        bias_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
        self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)

        self.gru_forward_layer = RNN(
            cell=GRUCell(
                input_size=128 * 6,  # channel * h
                hidden_size=rnn_hidden_size,
                param_attr=para_attr,
                bias_attr=bias_attr,
                candidate_activation='relu'),
            is_reverse=False,
            time_major=False)
        self.gru_backward_layer = RNN(
            cell=GRUCell(
                input_size=128 * 6,  # channel * h
                hidden_size=rnn_hidden_size,
                param_attr=para_attr,
                bias_attr=bias_attr,
                candidate_activation='relu'),
            is_reverse=True,
            time_major=False)

        self.encoded_proj_fc = Linear(
            rnn_hidden_size * 2, decoder_size, bias_attr=False)

    def forward(self, inputs):
        conv_features = self.ocr_convs(inputs)

        transpose_conv_features = fluid.layers.transpose(
            conv_features, perm=[0, 3, 1, 2])
        sliced_feature = fluid.layers.reshape(
            transpose_conv_features, [
                -1, transpose_conv_features.shape[1],
                transpose_conv_features.shape[2] *
                transpose_conv_features.shape[3]
            ],
            inplace=False)
        gru_forward, _ = self.gru_forward_layer(sliced_feature)
        gru_backward, _ = self.gru_backward_layer(sliced_feature)
        encoded_vector = fluid.layers.concat(
            input=[gru_forward, gru_backward], axis=2)
        encoded_proj = self.encoded_proj_fc(encoded_vector)
        return gru_backward, encoded_vector, encoded_proj


class DecoderCell(RNNCell):
    def __init__(self, encoder_size, decoder_size):
        super(DecoderCell, self).__init__()
        self.attention = SimpleAttention(decoder_size)
        self.gru_cell = GRUCell(
            input_size=encoder_size * 2 + decoder_size,
            # encoded_vector.shape[-1] + embed_size
            hidden_size=decoder_size)

    def forward(self, current_word, states, encoder_vec, encoder_proj):
        context = self.attention(encoder_vec, encoder_proj, states)
        decoder_inputs = layers.concat([current_word, context], axis=1)
        hidden, _ = self.gru_cell(decoder_inputs, states)
        return hidden, hidden


class GRUDecoderWithAttention(fluid.dygraph.Layer):
    def __init__(self, encoder_size, decoder_size, num_classes):
        super(GRUDecoderWithAttention, self).__init__()
        self.gru_attention = RNN(
            DecoderCell(encoder_size, decoder_size),
            is_reverse=False,
            time_major=False)
        self.out_layer = Linear(
            input_dim=decoder_size,
            output_dim=num_classes + 2,
            bias_attr=None,
            act='softmax')

    def forward(self, inputs, decoder_initial_states, encoder_vec,
                encoder_proj):
        out, _ = self.gru_attention(
            inputs,
            initial_states=decoder_initial_states,
            encoder_vec=encoder_vec,
            encoder_proj=encoder_proj)
        predict = self.out_layer(out)
        return predict


class OCRAttention(Model):
    def __init__(self, num_classes, encoder_size, decoder_size,
                 word_vector_dim):
        super(OCRAttention, self).__init__()
        self.encoder_net = EncoderNet(decoder_size)
        self.fc = Linear(
            input_dim=encoder_size,
            output_dim=decoder_size,
            bias_attr=False,
            act='relu')
        self.embedding = Embedding(
            [num_classes + 2, word_vector_dim], dtype='float32')
        self.gru_decoder_with_attention = GRUDecoderWithAttention(
            encoder_size, decoder_size, num_classes)

    def forward(self, inputs, label_in):
        gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
        decoder_boot = self.fc(gru_backward[:, 0])
        trg_embedding = self.embedding(label_in)
        prediction = self.gru_decoder_with_attention(
            trg_embedding, decoder_boot, encoded_vector, encoded_proj)
        return prediction


class CrossEntropyCriterion(Loss):
    def __init__(self):
        super(CrossEntropyCriterion, self).__init__()

    def forward(self, outputs, labels):
        predict, (label, mask) = outputs[0], labels
        loss = layers.cross_entropy(predict, label=label, soft_label=False)
        loss = layers.elementwise_mul(loss, mask, axis=0)
        loss = layers.reduce_sum(loss)
        return loss
seq2seq/train.py

@@ -28,7 +28,7 @@ from callbacks import ProgBarLogger
 from args import parse_args
 from seq2seq_base import BaseModel, CrossEntropyCriterion
 from seq2seq_attn import AttentionModel
-from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input
+from reader import create_data_loader


 def do_train(args):

@@ -38,7 +38,6 @@ def do_train(args):
     if args.enable_ce:
         fluid.default_main_program().random_seed = 102
         fluid.default_startup_program().random_seed = 102
-        args.shuffle = False

     # define model
     inputs = [

@@ -54,64 +53,25 @@ def do_train(args):
     labels = [Input([None, None, 1], "int64", name="label"), ]

     # def dataloader
-    data_loaders = [None, None]
-    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
-                     ] if args.eval_data_prefix else [args.train_data_prefix]
-    for i, data_prefix in enumerate(data_prefixes):
-        dataset = Seq2SeqDataset(
-            fpattern=data_prefix + "." + args.src_lang,
-            trg_fpattern=data_prefix + "." + args.tar_lang,
-            src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
-            trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
-            token_delimiter=None,
-            start_mark="<s>",
-            end_mark="</s>",
-            unk_mark="<unk>")
-        (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
-         unk_id) = dataset.get_vocab_summary()
-        batch_sampler = Seq2SeqBatchSampler(
-            dataset=dataset,
-            use_token_batch=False,
-            batch_size=args.batch_size,
-            pool_size=args.batch_size * 20,
-            sort_type=SortType.POOL,
-            shuffle=args.shuffle)
-        data_loader = DataLoader(
-            dataset=dataset,
-            batch_sampler=batch_sampler,
-            places=device,
-            feed_list=None if fluid.in_dygraph_mode() else
-            [x.forward() for x in inputs + labels],
-            collate_fn=partial(
-                prepare_train_input, bos_id=bos_id, eos_id=eos_id,
-                pad_id=eos_id),
-            num_workers=0,
-            return_list=True)
-        data_loaders[i] = data_loader
-    train_loader, eval_loader = data_loaders
+    train_loader, eval_loader = create_data_loader(args, device)

     model_maker = AttentionModel if args.attention else BaseModel
     model = model_maker(args.src_vocab_size, args.tar_vocab_size,
                         args.hidden_size, args.hidden_size, args.num_layers,
                         args.dropout)
+    optimizer = fluid.optimizer.Adam(
+        learning_rate=args.learning_rate, parameter_list=model.parameters())
+    optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
+        clip_norm=args.max_grad_norm)
     model.prepare(
-        fluid.optimizer.Adam(
-            learning_rate=args.learning_rate,
-            parameter_list=model.parameters()),
+        optimizer,
         CrossEntropyCriterion(),
         inputs=inputs,
         labels=labels)
     model.fit(train_data=train_loader,
               eval_data=eval_loader,
               epochs=args.max_epoch,
               eval_freq=1,
               save_freq=1,
               save_dir=args.model_path,
-              log_freq=1)
+              log_freq=1,
+              verbose=2)


 if __name__ == "__main__":
 ...
seq2seq/train_ocr.py (deleted, 100644 → 0)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import paddle.fluid.profiler as profiler
import paddle.fluid as fluid

import data_reader
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments, get_attention_feeder_data
from model import Input, set_device
from nets import OCRAttention, CrossEntropyCriterion
from eval import evaluate

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('epoch_num',         int,   30,         "Epoch number.")
add_arg('lr',                float, 0.001,      "Learning rate.")
add_arg('lr_decay_strategy', str,   "",         "Learning rate decay strategy.")
add_arg('log_period',        int,   200,        "Log period.")
add_arg('save_model_period', int,   2000,       "Save model period. '-1' means never saving the model.")
add_arg('eval_period',       int,   2000,       "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir',    str,   "./output", "The directory the model to be saved to.")
add_arg('train_images',      str,   None,       "The directory of images to be used for training.")
add_arg('train_list',        str,   None,       "The list file of images to be used for training.")
add_arg('test_images',       str,   None,       "The directory of images to be used for test.")
add_arg('test_list',         str,   None,       "The list file of images to be used for training.")
add_arg('init_model',        str,   None,       "The init model file of directory.")
add_arg('use_gpu',           bool,  True,       "Whether use GPU to train.")
add_arg('parallel',          bool,  False,      "Whether use parallel training.")
add_arg('profile',           bool,  False,      "Whether to use profiling.")
add_arg('skip_batch_num',    int,   0,          "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test',         bool,  False,      "Whether to skip test phase.")
# model hyper paramters
add_arg('encoder_size',      int,   200,        "Encoder size.")
add_arg('decoder_size',      int,   128,        "Decoder size.")
add_arg('word_vector_dim',   int,   128,        "Word vector dim.")
add_arg('num_classes',       int,   95,         "Number classes.")
add_arg('gradient_clip',     float, 5.0,        "Gradient clip value.")
add_arg('dynamic',           bool,  False,      "Whether to use dygraph.")


def train(args):
    device = set_device("gpu" if args.use_gpu else "cpu")
    fluid.enable_dygraph(device) if args.dynamic else None

    ocr_attention = OCRAttention(
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        num_classes=args.num_classes,
        word_vector_dim=args.word_vector_dim)
    LR = args.lr
    if args.lr_decay_strategy == "piecewise_decay":
        learning_rate = fluid.layers.piecewise_decay(
            [200000, 250000], [LR, LR * 0.1, LR * 0.01])
    else:
        learning_rate = LR
    optimizer = fluid.optimizer.Adam(
        learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
    # grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)

    inputs = [
        Input([None, 1, 48, 384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in"),
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask")
    ]
    ocr_attention.prepare(
        optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)

    train_reader = data_reader.data_reader(
        args.batch_size,
        shuffle=True,
        images_dir=args.train_images,
        list_file=args.train_list,
        data_type='train')
    # test_reader = data_reader.data_reader(
    #     args.batch_size,
    #     images_dir=args.test_images,
    #     list_file=args.test_list,
    #     data_type="test")
    # if not os.path.exists(args.save_model_dir):
    #     os.makedirs(args.save_model_dir)

    total_step = 0
    epoch_num = args.epoch_num
    for epoch in range(epoch_num):
        batch_id = 0
        total_loss = 0.0
        for data in train_reader():
            total_step += 1
            data_dict = get_attention_feeder_data(data)
            pixel = data_dict["pixel"]
            label_in = data_dict["label_in"].reshape([pixel.shape[0], -1])
            label_out = data_dict["label_out"].reshape([pixel.shape[0], -1])
            mask = data_dict["mask"].reshape(label_out.shape).astype("float32")

            avg_loss = ocr_attention.train(
                inputs=[pixel, label_in], labels=[label_out, mask])[0]
            total_loss += avg_loss

            if True:  #batch_id > 0 and batch_id % args.log_period == 0:
                print("epoch: {}, batch_id: {}, loss {}".format(
                    epoch, batch_id,
                    total_loss / args.batch_size / args.log_period))
                total_loss = 0.0
            batch_id += 1


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)
    if args.profile:
        if args.use_gpu:
            with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
                train(args)
        else:
            with profiler.profiler("CPU", sorted_key='total') as cpuprof:
                train(args)
    else:
        train(args)
transformer/reader.py

@@ -289,7 +289,6 @@ class Seq2SeqDataset(Dataset):
                  start_mark="<s>",
                  end_mark="<e>",
                  unk_mark="<unk>",
-                 only_src=False,
                  trg_fpattern=None,
                  byte_data=False):
         if byte_data:
 ...