Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
a8dc7ed3
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a8dc7ed3
编写于
10月 31, 2018
作者:
H
Hongyu Liu
提交者:
GitHub
10月 31, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1416 from phlrain/move_lm_to_nlp
move language model to PaddleNLP
上级
e2b98dc5
188abef5
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
887 addition
and
13 deletion
+887
-13
fluid/PaddleNLP/language_model/gru/.run_ce.sh
fluid/PaddleNLP/language_model/gru/.run_ce.sh
+0
-0
fluid/PaddleNLP/language_model/gru/README.md
fluid/PaddleNLP/language_model/gru/README.md
+0
-0
fluid/PaddleNLP/language_model/gru/_ce.py
fluid/PaddleNLP/language_model/gru/_ce.py
+0
-0
fluid/PaddleNLP/language_model/gru/infer.py
fluid/PaddleNLP/language_model/gru/infer.py
+0
-0
fluid/PaddleNLP/language_model/gru/train.py
fluid/PaddleNLP/language_model/gru/train.py
+8
-11
fluid/PaddleNLP/language_model/gru/train_on_cloud.py
fluid/PaddleNLP/language_model/gru/train_on_cloud.py
+2
-2
fluid/PaddleNLP/language_model/gru/utils.py
fluid/PaddleNLP/language_model/gru/utils.py
+0
-0
fluid/PaddleNLP/language_model/lstm/.run_ce.sh
fluid/PaddleNLP/language_model/lstm/.run_ce.sh
+11
-0
fluid/PaddleNLP/language_model/lstm/README.md
fluid/PaddleNLP/language_model/lstm/README.md
+76
-0
fluid/PaddleNLP/language_model/lstm/_ce.py
fluid/PaddleNLP/language_model/lstm/_ce.py
+56
-0
fluid/PaddleNLP/language_model/lstm/args.py
fluid/PaddleNLP/language_model/lstm/args.py
+40
-0
fluid/PaddleNLP/language_model/lstm/data/download_data.sh
fluid/PaddleNLP/language_model/lstm/data/download_data.sh
+4
-0
fluid/PaddleNLP/language_model/lstm/lm_model.py
fluid/PaddleNLP/language_model/lstm/lm_model.py
+285
-0
fluid/PaddleNLP/language_model/lstm/reader.py
fluid/PaddleNLP/language_model/lstm/reader.py
+105
-0
fluid/PaddleNLP/language_model/lstm/train.py
fluid/PaddleNLP/language_model/lstm/train.py
+300
-0
未找到文件。
fluid/PaddleNLP/language_model/.run_ce.sh
→
fluid/PaddleNLP/language_model/
gru/
.run_ce.sh
浏览文件 @
a8dc7ed3
文件已移动
fluid/PaddleNLP/language_model/README.md
→
fluid/PaddleNLP/language_model/
gru/
README.md
浏览文件 @
a8dc7ed3
文件已移动
fluid/PaddleNLP/language_model/_ce.py
→
fluid/PaddleNLP/language_model/
gru/
_ce.py
浏览文件 @
a8dc7ed3
文件已移动
fluid/PaddleNLP/language_model/infer.py
→
fluid/PaddleNLP/language_model/
gru/
infer.py
浏览文件 @
a8dc7ed3
文件已移动
fluid/PaddleNLP/language_model/train.py
→
fluid/PaddleNLP/language_model/
gru/
train.py
浏览文件 @
a8dc7ed3
...
...
@@ -22,10 +22,7 @@ def parse_args():
help
=
'If set, run
\
the task with continuous evaluation logs.'
)
parser
.
add_argument
(
'--num_devices'
,
type
=
int
,
default
=
1
,
help
=
'Number of GPU devices'
)
'--num_devices'
,
type
=
int
,
default
=
1
,
help
=
'Number of GPU devices'
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -129,15 +126,15 @@ def train(train_reader,
newest_ppl
=
0
for
data
in
train_reader
():
i
+=
1
lod_src_wordseq
=
utils
.
to_lodtensor
(
[
dat
[
0
]
for
dat
in
data
],
place
)
lod_dst_wordseq
=
utils
.
to_lodtensor
(
[
dat
[
1
]
for
dat
in
data
],
place
)
lod_src_wordseq
=
utils
.
to_lodtensor
(
[
dat
[
0
]
for
dat
in
data
],
place
)
lod_dst_wordseq
=
utils
.
to_lodtensor
(
[
dat
[
1
]
for
dat
in
data
],
place
)
ret_avg_cost
=
train_exe
.
run
(
feed
=
{
"src_wordseq"
:
lod_src_wordseq
,
"dst_wordseq"
:
lod_dst_wordseq
},
fetch_list
=
fetch_list
)
fetch_list
=
fetch_list
)
avg_ppl
=
np
.
exp
(
ret_avg_cost
[
0
])
newest_ppl
=
np
.
mean
(
avg_ppl
)
if
i
%
100
==
0
:
...
...
@@ -145,8 +142,8 @@ def train(train_reader,
t1
=
time
.
time
()
total_time
+=
t1
-
t0
print
(
"epoch:%d num_steps:%d time_cost(s):%f"
%
(
epoch_idx
,
i
,
total_time
/
epoch_idx
))
print
(
"epoch:%d num_steps:%d time_cost(s):%f"
%
(
epoch_idx
,
i
,
total_time
/
epoch_idx
))
if
pass_idx
==
pass_num
-
1
and
args
.
enable_ce
:
#Note: The following logs are special for CE monitoring.
...
...
fluid/PaddleNLP/language_model/train_on_cloud.py
→
fluid/PaddleNLP/language_model/
gru/
train_on_cloud.py
浏览文件 @
a8dc7ed3
...
...
@@ -236,8 +236,8 @@ def do_train(train_reader,
t1
=
time
.
time
()
total_time
+=
t1
-
t0
print
(
"epoch:%d num_steps:%d time_cost(s):%f"
%
(
epoch_idx
,
i
,
total_time
/
epoch_idx
))
print
(
"epoch:%d num_steps:%d time_cost(s):%f"
%
(
epoch_idx
,
i
,
total_time
/
epoch_idx
))
save_dir
=
"%s/epoch_%d"
%
(
model_dir
,
epoch_idx
)
feed_var_names
=
[
"src_wordseq"
,
"dst_wordseq"
]
...
...
fluid/PaddleNLP/language_model/utils.py
→
fluid/PaddleNLP/language_model/
gru/
utils.py
浏览文件 @
a8dc7ed3
文件已移动
fluid/PaddleNLP/language_model/lstm/.run_ce.sh
0 → 100644
浏览文件 @
a8dc7ed3
# Continuous-evaluation (CE) entry point for the LSTM language model:
# downloads the PTB data, trains the "small" configuration on GPU 0 and
# pipes the training output into _ce.py, which records the KPI lines.
export CUDA_VISIBLE_DEVICES=0

# Run the download in a subshell so the working directory is unchanged
# afterwards, and stop here if the data could not be fetched instead of
# launching a training run that is guaranteed to fail.
(cd data && sh download_data.sh) || exit 1

python train.py \
    --data_path data/simple-examples/data/ \
    --model_type small \
    --use_gpu True \
    --enable_ce | python _ce.py
fluid/PaddleNLP/language_model/lstm/README.md
0 → 100644
浏览文件 @
a8dc7ed3
# lstm lm
以下是本例的简要目录结构及说明:
```text
.
├── README.md            # 文档
├── train.py             # 训练脚本
├── reader.py            # 数据读取
└── lm_model.py          # 模型定义文件
```
## 简介
循环神经网络语言模型的介绍可以参阅论文 [Recurrent Neural Network Regularization](https://arxiv.org/abs/1409.2329)。本文主要说明基于 LSTM 的语言模型的实现,数据采用 PTB 数据集,下载地址为 <http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz>。
## 数据下载
用户可以自行下载数据并解压,也可以使用 data 目录中的脚本:`cd data; sh download_data.sh`
## 训练
运行命令
`CUDA_VISIBLE_DEVICES=0 python train.py --data_path data/simple-examples/data/ --model_type small --use_gpu True`
开始训练模型。
model_type 指定模型配置的大小,目前支持 small、medium、large 三种配置(train.py 中另有一个仅用于快速调试的 test 配置)。
实现采用双层的lstm,具体的参数和网络配置 可以参考 train.py, lm_model.py 文件中的设置
## 训练结果示例
p40中训练日志如下(small config), test 测试集仅在最后一个epoch完成后进行测试
```
text
epoch id 0
ppl 232 865.86505 1.0
ppl 464 632.76526 1.0
ppl 696 510.47153 1.0
ppl 928 437.60617 1.0
ppl 1160 393.38422 1.0
ppl 1392 353.05365 1.0
ppl 1624 325.73267 1.0
ppl 1856 305.488 1.0
ppl 2088 286.3128 1.0
ppl 2320 270.91504 1.0
train ppl 270.86246
valid ppl 181.867964379
...
ppl 2320 40.975872 0.001953125
train ppl 40.974102
valid ppl 117.85741214
test ppl 113.939103843
```
## 与tf结果对比
tf采用的版本是1.6
```
text
small config
train valid test
fluid 1.0 40.962 118.111 112.617
tf 1.6 40.492 118.329 113.788
medium config
train valid test
fluid 1.0 45.620 87.398 83.682
tf 1.6 45.594 87.363 84.015
large config
train valid test
fluid 1.0 37.221 82.358 78.137
tf 1.6 38.342 82.311 78.121
```
fluid/PaddleNLP/language_model/lstm/_ce.py
0 → 100644
浏览文件 @
a8dc7ed3
# this file is only used for continuous evaluation test!
import
os
import
sys
sys
.
path
.
append
(
os
.
environ
[
'ceroot'
])
from
kpi
import
CostKpi
from
kpi
import
DurationKpi
# KPI objects handed to the continuous-evaluation (CE) framework. The KPI
# names must match the keys that train.py prints when --enable_ce is set
# ("lstm_language_model_loss" / "lstm_language_model_duration").
# NOTE(review): the numeric arguments (0.02, 0) are passed straight to the
# `kpi` module; their meaning is defined there -- confirm before changing.
imikolov_20_avg_ppl_kpi = CostKpi('lstm_language_model_loss', 0.02, 0)
imikolov_20_pass_duration_kpi = DurationKpi(
    'lstm_language_model_duration', 0.02, 0, actived=True)

# Every KPI listed here is looked up by name in log_to_ce().
tracking_kpis = [
    imikolov_20_avg_ppl_kpi,
    imikolov_20_pass_duration_kpi,
]
def parse_log(log):
    '''
    Yield ``(kpi_name, kpi_value)`` pairs found in a training log.

    A KPI line has the form ``"<name>\\t<float>"``, for example
    ``"train_cost\\t1.0"``. The log is split on newlines and each line is
    stripped and split on tabs; the first field is the name and the second
    is converted with ``float``. The split fields of every line are echoed
    with ``print`` (kept from the original implementation).

    Lines that do not carry a tab-separated numeric value (blank lines,
    progress messages such as ``"epoch id 0"``) are skipped instead of
    raising ``IndexError``/``ValueError``, because the training script
    prints such lines to the same stream.

    Args:
        log: the complete log text.

    Yields:
        Tuples ``(kpi_name, kpi_value)`` with ``kpi_value`` as ``float``.
    '''
    for line in log.split('\n'):
        fs = line.strip().split('\t')
        print(fs)
        if len(fs) < 2:
            # Not a "name<TAB>value" record.
            continue
        kpi_name = fs[0]
        try:
            kpi_value = float(fs[1])
        except ValueError:
            # Second field is not numeric, so this is not a KPI record.
            continue
        yield kpi_name, kpi_value
def log_to_ce(log):
    """Record every KPI found in ``log`` on the matching tracked KPI object.

    Builds a name -> KPI lookup from ``tracking_kpis``, then for each
    ``(name, value)`` pair produced by ``parse_log`` prints the pair, adds
    the value as a record and persists the KPI. A name that is not in
    ``tracking_kpis`` raises ``KeyError``.
    """
    kpi_tracker = {kpi.name: kpi for kpi in tracking_kpis}

    for record in parse_log(log):
        kpi_name, kpi_value = record
        print(kpi_name, kpi_value)
        tracked = kpi_tracker[kpi_name]
        tracked.add_record(kpi_value)
        tracked.persist()
# Script entry point: the whole training log is read from standard input
# (the CE shell script pipes train.py's output into this file).
if __name__ == '__main__':
    log = sys.stdin.read()
    log_to_ce(log)
fluid/PaddleNLP/language_model/lstm/args.py
0 → 100644
浏览文件 @
a8dc7ed3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
distutils.util
def _str2bool(value):
    """Convert a command-line token such as "True"/"false"/"1" to a bool.

    Raises ``argparse.ArgumentTypeError`` for unrecognised text so that
    argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value, ))


def parse_args():
    """Parse the command-line options of the LSTM language-model trainer.

    Options:
        --model_type: configuration name, default "small".
        --data_path:  directory holding the train/valid/test text files.
        --para_init:  flag, stored as True when present.
        --use_gpu:    boolean text ("True"/"False"), default False.
        --log_path:   log file path; when unset logs go to the console.
        --enable_ce:  flag, stored as True when present.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_type",
        type=str,
        default="small",
        help="model_type [test|small|medium|large]")
    parser.add_argument(
        "--data_path", type=str, help="all the data for train,valid,test")
    parser.add_argument('--para_init', action='store_true')
    # ``type=bool`` would turn every non-empty string -- including
    # "False" -- into True, so the text is parsed explicitly.
    parser.add_argument(
        '--use_gpu', type=_str2bool, default=False, help='whether using gpu')
    parser.add_argument(
        '--log_path',
        help='path of the log file. If not set, logs are printed to console')
    parser.add_argument('--enable_ce', action='store_true')
    args = parser.parse_args()
    return args
fluid/PaddleNLP/language_model/lstm/data/download_data.sh
0 → 100644
浏览文件 @
a8dc7ed3
# Download the PTB "simple-examples" archive and unpack it in the current
# directory. The && chain keeps tar from running on a missing or partial
# archive and makes the script's exit status reflect a failed download.
wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz \
    && tar -xzvf simple-examples.tgz
fluid/PaddleNLP/language_model/lstm/lm_model.py
0 → 100644
浏览文件 @
a8dc7ed3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle.fluid.layers
as
layers
import
paddle.fluid
as
fluid
from
paddle.fluid.layers.control_flow
import
StaticRNN
as
PaddingRNN
import
numpy
as
np
def lm_model(hidden_size,
             vocab_size,
             batch_size,
             num_layers=2,
             num_steps=20,
             init_scale=0.1,
             dropout=None):
    """Build the stacked-LSTM language-model graph in the default program.

    Declares the feed variables ``x``, ``y``, ``init_hidden`` and
    ``init_cell``, embeds ``x``, runs ``num_layers`` hand-written LSTM
    layers over ``num_steps`` time steps, projects the outputs to
    ``vocab_size`` logits and computes a softmax cross-entropy loss that is
    averaged over axis 0 and summed over the time steps.

    Args:
        hidden_size: width of the embedding and of every LSTM layer.
        vocab_size: number of rows of the embedding / softmax columns.
        batch_size: accepted for interface compatibility; not read here.
        num_layers: number of stacked LSTM layers.
        num_steps: number of unrolled time steps per sample.
        init_scale: parameters are drawn from U(-init_scale, init_scale).
        dropout: dropout probability; ``None`` or ``0.0`` disables dropout.

    Returns:
        ``(loss, last_hidden, last_cell, feeding_list)`` where
        ``feeding_list`` names the four feed variables.
    """
    use_dropout = dropout is not None and dropout > 0.0

    def _uniform_init():
        # A fresh initializer object per parameter, as in the original code.
        return fluid.initializer.UniformInitializer(
            low=-init_scale, high=init_scale)

    def _apply_dropout(value):
        return layers.dropout(
            value,
            dropout_prob=dropout,
            dropout_implementation='upscale_in_train')

    def _build_layer_inputs(init_hidden, init_cell):
        """Create per-layer weights/biases and slice the initial states."""
        weight_1_arr = []
        bias_arr = []
        hidden_array = []
        cell_array = []
        for i in range(num_layers):
            weight_1 = layers.create_parameter(
                [hidden_size * 2, hidden_size * 4],
                dtype="float32",
                name="fc_weight1_" + str(i),
                default_initializer=_uniform_init())
            weight_1_arr.append(weight_1)
            bias_1 = layers.create_parameter(
                [hidden_size * 4],
                dtype="float32",
                name="fc_bias1_" + str(i),
                default_initializer=fluid.initializer.Constant(0.0))
            bias_arr.append(bias_1)

            # Layer i's initial state is row i of the stacked state tensor.
            pre_hidden = layers.slice(
                init_hidden, axes=[0], starts=[i], ends=[i + 1])
            pre_cell = layers.slice(
                init_cell, axes=[0], starts=[i], ends=[i + 1])
            pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size])
            pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size])
            hidden_array.append(pre_hidden)
            cell_array.append(pre_cell)
        return weight_1_arr, bias_arr, hidden_array, cell_array

    def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
        """LSTM stack implemented with StaticRNN (aliased as PaddingRNN).

        ``len`` is kept for signature parity with ``encoder_static``; it is
        not read here.
        """
        weight_1_arr, bias_arr, hidden_array, cell_array = \
            _build_layer_inputs(init_hidden, init_cell)

        # Swap the first two axes before stepping over the sequence.
        input_embedding = layers.transpose(input_embedding, perm=[1, 0, 2])
        rnn = PaddingRNN()

        with rnn.step():
            input = rnn.step_input(input_embedding)
            for k in range(num_layers):
                pre_hidden = rnn.memory(init=hidden_array[k])
                pre_cell = rnn.memory(init=cell_array[k])
                weight_1 = weight_1_arr[k]
                bias = bias_arr[k]

                nn = layers.concat([input, pre_hidden], 1)
                gate_input = layers.matmul(x=nn, y=weight_1)
                gate_input = layers.elementwise_add(gate_input, bias)

                # The four gates are consecutive hidden_size-wide column
                # slices of the fused projection.
                i = layers.slice(
                    gate_input, axes=[1], starts=[0], ends=[hidden_size])
                j = layers.slice(
                    gate_input,
                    axes=[1],
                    starts=[hidden_size],
                    ends=[hidden_size * 2])
                f = layers.slice(
                    gate_input,
                    axes=[1],
                    starts=[hidden_size * 2],
                    ends=[hidden_size * 3])
                o = layers.slice(
                    gate_input,
                    axes=[1],
                    starts=[hidden_size * 3],
                    ends=[hidden_size * 4])

                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
                    i) * layers.tanh(j)
                m = layers.tanh(c) * layers.sigmoid(o)

                rnn.update_memory(pre_hidden, m)
                rnn.update_memory(pre_cell, c)

                # Outputs 2*k and 2*k+1 are layer k's hidden and cell state.
                rnn.step_output(m)
                rnn.step_output(c)

                input = m
                if use_dropout:
                    input = _apply_dropout(input)

            # The last output is the top layer's (possibly dropped) output.
            rnn.step_output(input)

        rnnout = rnn()

        last_hidden_array = []
        last_cell_array = []
        real_res = rnnout[-1]
        for i in range(num_layers):
            m = rnnout[i * 2]
            c = rnnout[i * 2 + 1]
            m.stop_gradient = True
            c.stop_gradient = True
            # Keep only the final time step of each layer's states.
            last_h = layers.slice(
                m, axes=[0], starts=[num_steps - 1], ends=[num_steps])
            last_hidden_array.append(last_h)
            last_c = layers.slice(
                c, axes=[0], starts=[num_steps - 1], ends=[num_steps])
            last_cell_array.append(last_c)

        real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
        last_hidden = layers.concat(last_hidden_array, 0)
        last_cell = layers.concat(last_cell_array, 0)

        return real_res, last_hidden, last_cell

    def encoder_static(input_embedding,
                       len=3,
                       init_hidden=None,
                       init_cell=None):
        """LSTM stack unrolled in Python for ``len`` steps.

        Not called by ``lm_model`` itself; kept as the alternative
        implementation shipped with the original file.
        """
        weight_1_arr, bias_arr, hidden_array, cell_array = \
            _build_layer_inputs(init_hidden, init_cell)

        res = []
        for index in range(len):
            input = layers.slice(
                input_embedding, axes=[1], starts=[index], ends=[index + 1])
            input = layers.reshape(input, shape=[-1, hidden_size])
            for k in range(num_layers):
                pre_hidden = hidden_array[k]
                pre_cell = cell_array[k]
                weight_1 = weight_1_arr[k]
                bias = bias_arr[k]

                nn = layers.concat([input, pre_hidden], 1)
                gate_input = layers.matmul(x=nn, y=weight_1)
                gate_input = layers.elementwise_add(gate_input, bias)
                i, j, f, o = layers.split(
                    gate_input, num_or_sections=4, dim=-1)

                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
                    i) * layers.tanh(j)
                m = layers.tanh(c) * layers.sigmoid(o)

                hidden_array[k] = m
                cell_array[k] = c
                input = m
                if use_dropout:
                    input = _apply_dropout(input)

            res.append(layers.reshape(input, shape=[1, -1, hidden_size]))

        real_res = layers.concat(res, 0)
        real_res = layers.transpose(x=real_res, perm=[1, 0, 2])

        last_hidden = layers.concat(hidden_array, 1)
        last_hidden = layers.reshape(
            last_hidden, shape=[-1, num_layers, hidden_size])
        last_hidden = layers.transpose(x=last_hidden, perm=[1, 0, 2])

        last_cell = layers.concat(cell_array, 1)
        last_cell = layers.reshape(
            last_cell, shape=[-1, num_layers, hidden_size])
        last_cell = layers.transpose(x=last_cell, perm=[1, 0, 2])

        return real_res, last_hidden, last_cell

    x = layers.data(name="x", shape=[-1, 1, 1], dtype='int64')
    # Hard labels (soft_label=False below) are class indices, so the label
    # is declared int64 to match the int64 ids the reader feeds; the
    # original declared float32 here.
    y = layers.data(name="y", shape=[-1, 1], dtype='int64')

    init_hidden = layers.data(name="init_hidden", shape=[1], dtype='float32')
    init_cell = layers.data(name="init_cell", shape=[1], dtype='float32')

    init_hidden = layers.reshape(
        init_hidden, shape=[num_layers, -1, hidden_size])
    init_cell = layers.reshape(init_cell, shape=[num_layers, -1, hidden_size])

    x_emb = layers.embedding(
        input=x,
        size=[vocab_size, hidden_size],
        dtype='float32',
        is_sparse=True,
        param_attr=fluid.ParamAttr(
            name='embedding_para', initializer=_uniform_init()))

    x_emb = layers.reshape(x_emb, shape=[-1, num_steps, hidden_size])
    if use_dropout:
        x_emb = _apply_dropout(x_emb)

    rnn_out, last_hidden, last_cell = padding_rnn(
        x_emb, len=num_steps, init_hidden=init_hidden, init_cell=init_cell)
    rnn_out = layers.reshape(rnn_out, shape=[-1, num_steps, hidden_size])

    softmax_weight = layers.create_parameter(
        [hidden_size, vocab_size],
        dtype="float32",
        name="softmax_weight",
        default_initializer=_uniform_init())
    softmax_bias = layers.create_parameter(
        [vocab_size],
        dtype="float32",
        name='softmax_bias',
        default_initializer=_uniform_init())

    projection = layers.matmul(rnn_out, softmax_weight)
    projection = layers.elementwise_add(projection, softmax_bias)
    projection = layers.reshape(projection, shape=[-1, vocab_size])

    loss = layers.softmax_with_cross_entropy(
        logits=projection, label=y, soft_label=False)

    # Mean over axis 0 of the [-1, num_steps] loss, then sum over steps.
    loss = layers.reshape(loss, shape=[-1, num_steps])
    loss = layers.reduce_mean(loss, dim=[0])
    loss = layers.reduce_sum(loss)

    # The original set a non-existent ``permissions`` attribute, which has
    # no effect; ``persistable`` is the Variable flag that was intended.
    # NOTE(review): presumably needed so the loss can be fetched by name
    # from the cloned inference program -- confirm against train.py usage.
    loss.persistable = True

    feeding_list = ['x', 'y', 'init_hidden', 'init_cell']
    return loss, last_hidden, last_cell, feeding_list
fluid/PaddleNLP/language_model/lstm/reader.py
0 → 100644
浏览文件 @
a8dc7ed3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for parsing PTB text files."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
os
import
sys
import
numpy
as
np
Py3
=
sys
.
version_info
[
0
]
==
3
def
_read_words
(
filename
):
data
=
[]
with
open
(
filename
,
"r"
)
as
f
:
return
f
.
read
().
decode
(
"utf-8"
).
replace
(
"
\n
"
,
"<eos>"
).
split
()
def _build_vocab(filename):
    """Map every distinct word of ``filename`` to an integer id.

    Words are ranked by descending frequency, ties broken by the word
    itself in ascending order; the id is the rank (0 for the most frequent
    word). The vocabulary size is printed as a side effect.
    """
    frequencies = collections.Counter(_read_words(filename))
    ranked = sorted(frequencies.items(), key=lambda kv: (-kv[1], kv[0]))

    words, _counts = zip(*ranked)
    print("vocab word num", len(words))

    word_to_id = {word: index for index, word in enumerate(words)}
    return word_to_id
def _file_to_word_ids(filename, word_to_id):
    """Convert the words of ``filename`` to ids, dropping unknown words.

    Words that are not keys of ``word_to_id`` are silently skipped.
    """
    ids = []
    for token in _read_words(filename):
        if token in word_to_id:
            ids.append(word_to_id[token])
    return ids
def ptb_raw_data(data_path=None):
    """Load the PTB train/valid/test splits as lists of word ids.

    The vocabulary is built from ``ptb.train.txt`` only and then used to
    convert all three files found in ``data_path``.

    The PTB dataset comes from Tomas Mikolov's webpage:
    http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz

    Args:
        data_path: directory that contains ``ptb.train.txt``,
            ``ptb.valid.txt`` and ``ptb.test.txt``.

    Returns:
        tuple ``(train_data, valid_data, test_data, vocabulary)`` where the
        first three are lists of ids and ``vocabulary`` is the number of
        distinct words in the training file.
    """
    file_names = ("ptb.train.txt", "ptb.valid.txt", "ptb.test.txt")
    split_paths = [os.path.join(data_path, name) for name in file_names]

    word_to_id = _build_vocab(split_paths[0])
    splits = [_file_to_word_ids(path, word_to_id) for path in split_paths]

    return splits[0], splits[1], splits[2], len(word_to_id)
def get_data_iter(raw_data, batch_size, num_steps):
    """Yield ``(x, y)`` int64 batches of shape ``(batch_size, num_steps)``.

    ``raw_data`` is truncated to a multiple of ``batch_size`` and laid out
    as ``batch_size`` contiguous rows. Each batch takes ``num_steps``
    consecutive columns as ``x``; ``y`` is the same window shifted one
    column to the right (the next-token targets). Both arrays are copies.
    """
    total_tokens = len(raw_data)
    tokens = np.asarray(raw_data, dtype="int64")

    row_len = total_tokens // batch_size
    grid = tokens[:batch_size * row_len].reshape((batch_size, row_len))

    # One column is reserved so that the shifted target window always fits.
    num_batches = (row_len - 1) // num_steps
    for step in range(num_batches):
        lo = step * num_steps
        hi = lo + num_steps
        inputs = np.copy(grid[:, lo:hi])
        targets = np.copy(grid[:, lo + 1:hi + 1])
        yield (inputs, targets)
fluid/PaddleNLP/language_model/lstm/train.py
0 → 100644
浏览文件 @
a8dc7ed3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
time
import
os
import
random
import
math
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
import
paddle.fluid.framework
as
framework
from
paddle.fluid.executor
import
Executor
import
reader
import
sys
if
sys
.
version
[
0
]
==
'2'
:
reload
(
sys
)
sys
.
setdefaultencoding
(
"utf-8"
)
sys
.
path
.
append
(
'..'
)
import
os
os
.
environ
[
"TF_CPP_MIN_LOG_LEVEL"
]
=
"3"
from
args
import
*
import
lm_model
import
logging
import
pickle
SEED
=
123
def get_current_model_para(train_prog, train_exe):
    """Return ``{parameter name: numpy array}`` for block 0 of ``train_prog``.

    Values are read from the global scope. ``train_exe`` is accepted but
    not used.
    """
    names = [param.name for param in train_prog.block(0).all_parameters()]
    return {
        name: np.array(fluid.global_scope().find_var(name).get_tensor())
        for name in names
    }
def save_para_npz(train_prog, train_exe):
    """Dump all parameters of ``train_prog`` into one ``.npz`` archive.

    The arrays are collected with ``get_current_model_para`` and written
    with ``np.savez("mode_base", ...)``, keyed by parameter name.

    Changes from the previous version: the parameter-collection loop that
    duplicated ``get_current_model_para`` is gone, the unused
    ``vals["embedding_para"]`` lookup (which raised ``KeyError`` for any
    program without a parameter of that name) is removed, and the progress
    message is printed once instead of twice.

    NOTE(review): the message says "model_base" while the archive name is
    "mode_base"; the file name is left unchanged because code elsewhere
    may load it -- confirm which spelling is intended.
    """
    print("begin to save model to model_base")
    vals = get_current_model_para(train_prog, train_exe)
    np.savez("mode_base", **vals)
def train():
    """Train the LSTM language model selected by ``--model_type`` on PTB.

    Steps performed:
      1. parse the command line and configure the "lm" logger;
      2. pick the hyper-parameters of the chosen configuration
         ("test", "small", "medium" or "large");
      3. build the model, a test-mode clone for evaluation, global-norm
         gradient clipping and an SGD optimizer whose learning rate is fed
         on every step;
      4. for every epoch: train over the whole training set carrying the
         LSTM state from batch to batch, save the persistables under
         ``model_new/<epoch_id>`` and print the validation perplexity;
      5. after the last epoch print the test perplexity.

    With ``--enable_ce`` the startup program is seeded and two KPI lines
    (``lstm_language_model_duration`` and ``lstm_language_model_loss``) are
    printed after the final epoch.

    Returns early (``None``) for an unknown model type, or when the
    training perplexity of the first epoch exceeds 1000.
    """
    args = parse_args()
    model_type = args.model_type

    logger = logging.getLogger("lm")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    if args.enable_ce:
        fluid.default_startup_program().random_seed = SEED
    if args.log_path:
        handler = logging.FileHandler(args.log_path)
    else:
        handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    logger.info('Running with args : {}'.format(args))

    vocab_size = 10000

    # Column order: num_layers, batch_size, hidden_size, num_steps,
    # init_scale, max_grad_norm, epoch_start_decay, max_epoch, dropout,
    # lr_decay, base_learning_rate.
    model_configs = {
        "test": (1, 2, 10, 3, 0.1, 5.0, 1, 1, 0.0, 0.5, 1.0),
        "small": (2, 20, 200, 20, 0.1, 5.0, 4, 13, 0.0, 0.5, 1.0),
        "medium": (2, 20, 650, 35, 0.05, 5.0, 6, 39, 0.5, 0.8, 1.0),
        "large": (2, 20, 1500, 35, 0.04, 10.0, 14, 55, 0.65, 1.0 / 1.15,
                  1.0),
    }
    if model_type not in model_configs:
        print("model type not support")
        return
    (num_layers, batch_size, hidden_size, num_steps, init_scale,
     max_grad_norm, epoch_start_decay, max_epoch, dropout, lr_decay,
     base_learning_rate) = model_configs[model_type]

    # Training process
    loss, last_hidden, last_cell, _ = lm_model.lm_model(
        hidden_size,
        vocab_size,
        batch_size,
        num_layers=num_layers,
        num_steps=num_steps,
        init_scale=init_scale,
        dropout=dropout)

    # The evaluation program is cloned before the optimizer ops are added.
    main_program = fluid.default_main_program()
    inference_program = fluid.default_main_program().clone(for_test=True)

    fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
        clip_norm=max_grad_norm))

    # The learning rate lives in a global variable that is overwritten
    # through the feed dict on every training step.
    learning_rate = fluid.layers.create_global_var(
        name="learning_rate",
        shape=[1],
        value=1.0,
        dtype='float32',
        persistable=True)

    optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
    optimizer.minimize(loss)

    place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
    exe = Executor(place)
    exe.run(framework.default_startup_program())

    data_path = args.data_path
    print("begin to load data")
    raw_data = reader.ptb_raw_data(data_path)
    print("finished load data")
    train_data, valid_data, test_data, _ = raw_data

    def prepare_input(batch, init_hidden, init_cell, epoch_id=0,
                      with_lr=True):
        """Build the feed dict for one batch.

        The learning rate is ``base_learning_rate`` multiplied by
        ``lr_decay`` once for every epoch from ``epoch_start_decay`` on.
        """
        x, y = batch
        new_lr = base_learning_rate * (lr_decay**max(
            epoch_id + 1 - epoch_start_decay, 0.0))
        lr = np.ones((1), dtype='float32') * new_lr

        res = {
            'x': x.reshape((-1, num_steps, 1)),
            'y': y.reshape((-1, 1)),
            'init_hidden': init_hidden,
            'init_cell': init_cell,
        }
        if with_lr:
            res['learning_rate'] = lr
        return res

    def evaluate(data):
        """Return the perplexity of ``data`` using the inference program."""
        # when eval the batch_size set to 1
        eval_data_iter = reader.get_data_iter(data, 1, num_steps)
        total_loss = 0.0
        iters = 0
        init_hidden = np.zeros(
            (num_layers, 1, hidden_size), dtype='float32')
        init_cell = np.zeros((num_layers, 1, hidden_size), dtype='float32')
        for batch in eval_data_iter:
            input_data_feed = prepare_input(
                batch, init_hidden, init_cell, with_lr=False)
            fetch_outs = exe.run(
                inference_program,
                feed=input_data_feed,
                fetch_list=[loss.name, last_hidden.name, last_cell.name])

            cost_train = np.array(fetch_outs[0])
            # Carry the final LSTM state over to the next batch.
            init_hidden = np.array(fetch_outs[1])
            init_cell = np.array(fetch_outs[2])

            total_loss += cost_train
            iters += num_steps

        return np.exp(total_loss / iters)

    # get train epoch size
    batch_len = len(train_data) // batch_size
    epoch_size = (batch_len - 1) // num_steps
    # At least 1: with fewer than 10 batches per epoch the plain integer
    # division is 0 and ``batch_id % log_interval`` would divide by zero.
    log_interval = max(epoch_size // 10, 1)
    total_time = 0.0
    for epoch_id in range(max_epoch):
        start_time = time.time()
        print("epoch id", epoch_id)
        train_data_iter = reader.get_data_iter(train_data, batch_size,
                                               num_steps)

        total_loss = 0
        iters = 0
        # The LSTM state is reset to zeros at the start of every epoch.
        init_hidden = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32')
        init_cell = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32')
        for batch_id, batch in enumerate(train_data_iter):
            input_data_feed = prepare_input(
                batch, init_hidden, init_cell, epoch_id=epoch_id)
            fetch_outs = exe.run(feed=input_data_feed,
                                 fetch_list=[
                                     loss.name, last_hidden.name,
                                     last_cell.name, 'learning_rate'
                                 ])

            cost_train = np.array(fetch_outs[0])
            init_hidden = np.array(fetch_outs[1])
            init_cell = np.array(fetch_outs[2])
            lr = np.array(fetch_outs[3])

            total_loss += cost_train
            iters += num_steps
            if batch_id > 0 and batch_id % log_interval == 0:
                ppl = np.exp(total_loss / iters)
                print("ppl ", batch_id, ppl[0], lr[0])

        ppl = np.exp(total_loss / iters)
        if epoch_id == 0 and ppl[0] > 1000:
            # for bad init, after first epoch, the loss is over 1000
            # no more need to continue
            return
        end_time = time.time()
        total_time += end_time - start_time
        print("train ppl", ppl[0])

        if epoch_id == max_epoch - 1 and args.enable_ce:
            # Note: these two lines are parsed by _ce.py (name<TAB>value).
            print("lstm_language_model_duration\t%s" %
                  (total_time / max_epoch))
            print("lstm_language_model_loss\t%s" % ppl[0])

        model_path = os.path.join("model_new/", str(epoch_id))
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        fluid.io.save_persistables(
            executor=exe, dirname=model_path, main_program=main_program)
        valid_ppl = evaluate(valid_data)
        print("valid ppl", valid_ppl[0])

    test_ppl = evaluate(test_data)
    print("test ppl", test_ppl[0])
# Script entry point: run the full training loop with the options taken
# from the command line.
if __name__ == '__main__':
    train()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录