PaddlePaddle / hapi — commit 54b3b726

Authored on Apr 28, 2020 by wangxiao1021

add emotion_detection and update senta

Parent: cea791c6
Showing 12 changed files with 772 additions and 36 deletions (+772 / -36)
examples/emotion_detection/config.yaml                      +23   -0
examples/emotion_detection/download.py                      +123  -0
examples/emotion_detection/download_data.sh                 +8    -0
examples/emotion_detection/models.py                        +179  -0
examples/emotion_detection/run_classifier.py                +158  -0
examples/sentiment_classification/models.py                 +33   -7
examples/sentiment_classification/sentiment_classifier.py   +27   -28
hapi/text/emo_tect/__init__.py                              +15   -0
hapi/text/emo_tect/data_processor.py                        +79   -0
hapi/text/emo_tect/data_reader.py                           +126  -0
hapi/text/senta/__init__.py                                 +1    -1
hapi/text/senta/data_processor.py                           +0    -0
examples/emotion_detection/config.yaml (new file, mode 100644)
model_type: "bow_net"
num_labels: 3
vocab_size: 240465
vocab_path: "./data/vocab.txt"
data_dir: "./data"
inference_model_dir: "./inference_model"
save_checkpoint_dir: ""
init_checkpoint: ""
checkpoints: "./checkpoints/"
lr: 0.02
epoch: 10
batch_size: 24
do_train: True
do_val: True
do_infer: False
do_save_inference_model: False
max_seq_len: 20
skip_steps: 10
save_freq: 1
eval_freq: 1
random_seed: 0
output_dir: "./output"
use_cuda: True
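For reference, run_classifier.py later in this commit consumes this file through hapi.configure.Config; a minimal sketch of that pattern, using only the calls that script makes:

from hapi.configure import Config

args = Config(yaml_file='./config.yaml')
args.build()
args.Print()                      # dumps the parsed settings
print(args.model_type, args.lr)   # each YAML key is exposed as an attribute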
examples/emotion_detection/download.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
io
import
os
import
sys
import
time
import
hashlib
import
tarfile
import
requests
def
usage
():
desc
=
(
"
\n
Download datasets and pretrained models for EmotionDetection task.
\n
"
"Usage:
\n
"
" python download.py dataset
\n
"
)
print
(
desc
)
def
md5file
(
fname
):
hash_md5
=
hashlib
.
md5
()
with
io
.
open
(
fname
,
"rb"
)
as
fin
:
for
chunk
in
iter
(
lambda
:
fin
.
read
(
4096
),
b
""
):
hash_md5
.
update
(
chunk
)
return
hash_md5
.
hexdigest
()
def
extract
(
fname
,
dir_path
):
"""
Extract tar.gz file
"""
try
:
tar
=
tarfile
.
open
(
fname
,
"r:gz"
)
file_names
=
tar
.
getnames
()
for
file_name
in
file_names
:
tar
.
extract
(
file_name
,
dir_path
)
print
(
file_name
)
tar
.
close
()
except
Exception
as
e
:
raise
e
def
download
(
url
,
filename
,
md5sum
):
"""
Download file and check md5
"""
retry
=
0
retry_limit
=
3
chunk_size
=
4096
while
not
(
os
.
path
.
exists
(
filename
)
and
md5file
(
filename
)
==
md5sum
):
if
retry
<
retry_limit
:
retry
+=
1
else
:
raise
RuntimeError
(
"Cannot download dataset ({0}) with retry {1} times."
.
format
(
url
,
retry_limit
))
try
:
start
=
time
.
time
()
size
=
0
res
=
requests
.
get
(
url
,
stream
=
True
)
filesize
=
int
(
res
.
headers
[
'content-length'
])
if
res
.
status_code
==
200
:
print
(
"[Filesize]: %0.2f MB"
%
(
filesize
/
1024
/
1024
))
# save by chunk
with
io
.
open
(
filename
,
"wb"
)
as
fout
:
for
chunk
in
res
.
iter_content
(
chunk_size
=
chunk_size
):
if
chunk
:
fout
.
write
(
chunk
)
size
+=
len
(
chunk
)
pr
=
'>'
*
int
(
size
*
50
/
filesize
)
print
(
'
\r
[Process ]: %s%.2f%%'
%
(
pr
,
float
(
size
/
filesize
*
100
)),
end
=
''
)
end
=
time
.
time
()
print
(
"
\n
[CostTime]: %.2f s"
%
(
end
-
start
))
except
Exception
as
e
:
print
(
e
)
def
download_dataset
(
dir_path
):
BASE_URL
=
"https://baidu-nlp.bj.bcebos.com/"
DATASET_NAME
=
"emotion_detection-dataset-1.0.0.tar.gz"
DATASET_MD5
=
"512d256add5f9ebae2c101b74ab053e9"
file_path
=
os
.
path
.
join
(
dir_path
,
DATASET_NAME
)
url
=
BASE_URL
+
DATASET_NAME
if
not
os
.
path
.
exists
(
dir_path
):
os
.
makedirs
(
dir_path
)
# download dataset
print
(
"Downloading dataset: %s"
%
url
)
download
(
url
,
file_path
,
DATASET_MD5
)
# extract dataset
print
(
"Extracting dataset: %s"
%
file_path
)
extract
(
file_path
,
dir_path
)
os
.
remove
(
file_path
)
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
!=
2
:
usage
()
sys
.
exit
(
1
)
if
sys
.
argv
[
1
]
==
"dataset"
:
pwd
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'./'
)
download_dataset
(
pwd
)
else
:
usage
()
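The intended entry point is `python download.py dataset`, which fetches the archive, verifies its md5, and unpacks it next to the script. The same can be done programmatically; a small sketch using the function above (the './data/' target mirrors download_data.sh and is an assumption):

from download import download_dataset

# download, md5-check, extract, then delete the archive
download_dataset('./data/')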
examples/emotion_detection/download_data.sh (new file, mode 100644)
#!/bin/bash
# download dataset file to ./data/
DATA_URL=https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz
wget --no-check-certificate ${DATA_URL}
tar xvf emotion_detection-dataset-1.0.0.tar.gz
/bin/rm emotion_detection-dataset-1.0.0.tar.gz
examples/emotion_detection/models.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear, Embedding
from paddle.fluid.dygraph.base import to_variable
import numpy as np

from hapi.model import Model
from hapi.text.text import GRUEncoderLayer as BiGRUEncoder
from hapi.text.text import BOWEncoder, CNNEncoder, GRUEncoder, LSTMEncoder


class CNN(Model):
    def __init__(self, dict_dim, seq_len):
        super(CNN, self).__init__()
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.channels = 1
        self.win_size = [3, self.hid_dim]
        self.seq_len = seq_len
        self._encoder = CNNEncoder(
            dict_size=self.dict_dim + 1,
            emb_dim=self.emb_dim,
            seq_len=self.seq_len,
            filter_size=self.win_size,
            num_filters=self.hid_dim,
            hidden_dim=self.hid_dim,
            padding_idx=None,
            act='tanh')
        self._fc1 = Linear(
            input_dim=self.hid_dim * self.seq_len,
            output_dim=self.fc_hid_dim,
            act="softmax")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")

    def forward(self, inputs):
        conv_3 = self._encoder(inputs)
        fc_1 = self._fc1(conv_3)
        prediction = self._fc_prediction(fc_1)
        return prediction


class BOW(Model):
    def __init__(self, dict_dim, seq_len):
        super(BOW, self).__init__()
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.seq_len = seq_len
        self._encoder = BOWEncoder(
            dict_size=self.dict_dim + 1,
            emb_dim=self.emb_dim,
            padding_idx=None,
            bow_dim=self.hid_dim,
            seq_len=self.seq_len)
        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.hid_dim, act="tanh")
        self._fc2 = Linear(
            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")

    def forward(self, inputs):
        bow_1 = self._encoder(inputs)
        bow_1 = fluid.layers.tanh(bow_1)
        fc_1 = self._fc1(bow_1)
        fc_2 = self._fc2(fc_1)
        prediction = self._fc_prediction(fc_2)
        return prediction


class GRU(Model):
    def __init__(self, dict_dim, seq_len):
        super(GRU, self).__init__()
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.seq_len = seq_len
        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")
        self._encoder = GRUEncoder(
            dict_size=self.dict_dim + 1,
            emb_dim=self.emb_dim,
            gru_dim=self.hid_dim,
            hidden_dim=self.hid_dim,
            padding_idx=None,
            seq_len=self.seq_len)

    def forward(self, inputs):
        emb = self._encoder(inputs)
        fc_1 = self._fc1(emb)
        prediction = self._fc_prediction(fc_1)
        return prediction


class BiGRU(Model):
    def __init__(self, dict_dim, batch_size, seq_len):
        super(BiGRU, self).__init__()
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.embedding = Embedding(
            size=[self.dict_dim + 1, self.emb_dim],
            dtype='float32',
            param_attr=fluid.ParamAttr(learning_rate=30),
            is_sparse=False)
        h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
        h_0 = to_variable(h_0)
        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.hid_dim * 3)
        self._fc2 = Linear(
            input_dim=self.hid_dim * 2,
            output_dim=self.fc_hid_dim,
            act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")
        self._encoder = BiGRUEncoder(
            grnn_hidden_dim=self.hid_dim,
            input_dim=self.hid_dim * 3,
            h_0=h_0,
            init_bound=0.1,
            is_bidirection=True)

    def forward(self, inputs):
        emb = self.embedding(inputs)
        emb = fluid.layers.reshape(
            emb, shape=[self.batch_size, -1, self.hid_dim])
        fc_1 = self._fc1(emb)
        encoded_vector = self._encoder(fc_1)
        encoded_vector = fluid.layers.tanh(encoded_vector)
        encoded_vector = fluid.layers.reduce_max(encoded_vector, dim=1)
        fc_2 = self._fc2(encoded_vector)
        prediction = self._fc_prediction(fc_2)
        return prediction


class LSTM(Model):
    def __init__(self, dict_dim, seq_len):
        super(LSTM, self).__init__()
        self.seq_len = seq_len
        self.dict_dim = dict_dim
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 3
        self.emb_lr = 30.0
        self._encoder = LSTMEncoder(
            dict_size=dict_dim + 1,
            emb_dim=self.emb_dim,
            lstm_dim=self.hid_dim,
            hidden_dim=self.hid_dim,
            seq_len=self.seq_len,
            padding_idx=None,
            is_reverse=False)
        self._fc1 = Linear(
            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
        self._fc_prediction = Linear(
            input_dim=self.fc_hid_dim,
            output_dim=self.class_dim,
            act="softmax")

    def forward(self, inputs):
        emb = self._encoder(inputs)
        fc_1 = self._fc1(emb)
        prediction = self._fc_prediction(fc_1)
        return prediction
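A quick, hypothetical smoke test for these classes, assuming a dygraph context and the vocab_size/max_seq_len values from config.yaml above:

import numpy as np
import paddle.fluid as fluid
from models import BOW

fluid.enable_dygraph()
model = BOW(dict_dim=240465, seq_len=20)   # values from config.yaml
# fake batch of padded word ids, shaped like the 'doc' input in run_classifier.py
fake_ids = np.random.randint(0, 240465, (24, 20)).astype('int64')
probs = model.forward(fluid.dygraph.to_variable(fake_ids))
print(probs.shape)                          # [24, 3]: softmax over the 3 emotion labels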
examples/emotion_detection/run_classifier.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Emotion Detection Task in Paddle Dygraph Mode.
"""
from __future__ import print_function

import os
import json

import paddle
import paddle.fluid as fluid
import numpy as np

from hapi.model import set_device, CrossEntropy, Input
from hapi.metrics import Accuracy
from hapi.text.emo_tect import EmoTectProcessor
from hapi.configure import Config
from models import CNN, BOW, GRU, BiGRU, LSTM


def main():
    """
    Main Function
    """
    args = Config(yaml_file='./config.yaml')
    args.build()
    args.Print()

    if not (args.do_train or args.do_val or args.do_infer):
        raise ValueError("For args `do_train`, `do_val` and `do_infer`, at "
                         "least one of them must be True.")

    place = set_device("gpu" if args.use_cuda else "cpu")
    fluid.enable_dygraph(place)

    processor = EmoTectProcessor(
        data_dir=args.data_dir,
        vocab_path=args.vocab_path,
        random_seed=args.random_seed)
    num_labels = args.num_labels

    if args.model_type == 'cnn_net':
        model = CNN(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'bow_net':
        model = BOW(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'lstm_net':
        model = LSTM(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'gru_net':
        model = GRU(args.vocab_size, args.max_seq_len)
    elif args.model_type == 'bigru_net':
        model = BiGRU(args.vocab_size, args.batch_size, args.max_seq_len)
    else:
        raise ValueError("Unknown model type!")

    inputs = [Input([None, args.max_seq_len], 'int64', name='doc')]
    optimizer = None
    labels = None

    if args.do_train:
        train_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            places=place,
            phase='train',
            epoch=args.epoch,
            padding_size=args.max_seq_len)

        num_train_examples = processor.get_num_examples(phase="train")
        max_train_steps = args.epoch * num_train_examples // args.batch_size + 1

        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)

        labels = [Input([None, 1], 'int64', name='label')]
        optimizer = fluid.optimizer.Adagrad(
            learning_rate=args.lr, parameter_list=model.parameters())

    test_data_generator = None
    if args.do_train and args.do_val:
        test_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='dev',
            epoch=1,
            places=place,
            padding_size=args.max_seq_len)
    elif args.do_val:
        test_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='test',
            epoch=1,
            places=place,
            padding_size=args.max_seq_len)
    elif args.do_infer:
        infer_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='infer',
            epoch=1,
            places=place,
            padding_size=args.max_seq_len)

    model.prepare(
        optimizer,
        CrossEntropy(),
        Accuracy(topk=(1, )),
        inputs,
        labels,
        device=place)

    if args.do_train:
        if args.init_checkpoint:
            model.load(args.init_checkpoint)
    elif args.do_val or args.do_infer:
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing validation or infer!")
        model.load(args.init_checkpoint, reset_optimizer=True)

    if args.do_train:
        model.fit(
            train_data=train_data_generator,
            eval_data=test_data_generator,
            batch_size=args.batch_size,
            epochs=args.epoch,
            save_dir=args.checkpoints,
            eval_freq=args.eval_freq,
            save_freq=args.save_freq)
    elif args.do_val:
        eval_result = model.evaluate(
            eval_data=test_data_generator, batch_size=args.batch_size)
        print("Final eval result: acc: {:.4f}, loss: {:.4f}".format(
            eval_result['acc'], eval_result['loss'][0]))
    elif args.do_infer:
        preds = model.predict(test_data=infer_data_generator)
        preds = np.array(preds[0]).reshape((-1, args.num_labels))

        if args.output_dir:
            with open(os.path.join(args.output_dir, 'predictions.json'),
                      'w') as w:
                for p in range(len(preds)):
                    label = int(np.argmax(preds[p]))
                    result = json.dumps({
                        'index': p,
                        'label': label,
                        'probs': preds[p].tolist()
                    })
                    w.write(result + '\n')
            print('Predictions saved at ' +
                  os.path.join(args.output_dir, 'predictions.json'))


if __name__ == "__main__":
    main()
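Putting the pieces together, the expected workflow for this example is: fetch the data with `sh download_data.sh` (or `python download.py dataset`), adjust config.yaml (model_type, do_train/do_val/do_infer, use_cuda), then run `python run_classifier.py`; with do_infer enabled and output_dir set, predictions land in ./output/predictions.json.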
examples/sentiment_classification/models.py (modified)
@@ -17,11 +17,11 @@ from paddle.fluid.dygraph.base import to_variable
 import numpy as np

 from hapi.model import Model
 from hapi.text.text import GRUEncoderLayer as BiGRUEncoder
-from hapi.text.test import BOWEncoder, CNNEncoder, GRUEncoder
+from hapi.text.text import BOWEncoder, CNNEncoder, GRUEncoder, LSTMEncoder


 class CNN(Model):
-    def __init__(self, dict_dim, batch_size, seq_len):
+    def __init__(self, dict_dim, seq_len):
         super(CNN, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
@@ -30,7 +30,6 @@ class CNN(Model):
         self.class_dim = 2
         self.channels = 1
         self.win_size = [3, self.hid_dim]
-        self.batch_size = batch_size
         self.seq_len = seq_len
         self._encoder = CNNEncoder(
             dict_size=self.dict_dim + 1,
@@ -54,14 +53,13 @@ class CNN(Model):

 class BOW(Model):
-    def __init__(self, dict_dim, batch_size, seq_len):
+    def __init__(self, dict_dim, seq_len):
         super(BOW, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
         self.fc_hid_dim = 96
         self.class_dim = 2
-        self.batch_size = batch_size
         self.seq_len = seq_len
         self._encoder = BOWEncoder(
             dict_size=self.dict_dim + 1,
@@ -85,14 +83,13 @@ class BOW(Model):

 class GRU(Model):
-    def __init__(self, dict_dim, batch_size, seq_len):
+    def __init__(self, dict_dim, seq_len):
         super(GRU, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
         self.fc_hid_dim = 96
         self.class_dim = 2
-        self.batch_size = batch_size
         self.seq_len = seq_len
         self._fc1 = Linear(
             input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
         self._fc_prediction = Linear(
             input_dim=self.fc_hid_dim,
@@ -152,3 +149,32 @@ class BiGRU(Model):
         fc_2 = self._fc2(encoded_vector)
         prediction = self._fc_prediction(fc_2)
         return prediction
+
+
+class LSTM(Model):
+    def __init__(self, dict_dim, seq_len):
+        super(LSTM, self).__init__()
+        self.seq_len = seq_len
+        self.dict_dim = dict_dim
+        self.emb_dim = 128
+        self.hid_dim = 128
+        self.fc_hid_dim = 96
+        self.class_dim = 2
+        self.emb_lr = 30.0
+        self._encoder = LSTMEncoder(
+            dict_size=dict_dim + 1,
+            emb_dim=self.emb_dim,
+            lstm_dim=self.hid_dim,
+            hidden_dim=self.hid_dim,
+            seq_len=self.seq_len,
+            padding_idx=None,
+            is_reverse=False)
+        self._fc1 = Linear(
+            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(
+            input_dim=self.fc_hid_dim,
+            output_dim=self.class_dim,
+            act="softmax")
+
+    def forward(self, inputs):
+        emb = self._encoder(inputs)
+        fc_1 = self._fc1(emb)
+        prediction = self._fc_prediction(fc_1)
+        return prediction
examples/sentiment_classification/sentiment_classifier.py (modified)
@@ -17,11 +17,11 @@
 from __future__ import print_function

 import numpy as np
 import paddle.fluid as fluid
-from hapi.model import set_device, Model, CrossEntropy, Input
+from hapi.model import set_device, CrossEntropy, Input
 from hapi.configure import Config
 from hapi.text.senta import SentaProcessor
 from hapi.metrics import Accuracy
-from models import CNN, BOW, GRU, BiGRU
+from models import CNN, BOW, GRU, BiGRU, LSTM
 import json
 import os
@@ -38,6 +38,26 @@ def main():
     elif args.do_infer:
         infer()


+def create_model():
+    if args.model_type == 'cnn_net':
+        model = CNN(args.vocab_size, args.padding_size)
+    elif args.model_type == 'bow_net':
+        model = BOW(args.vocab_size, args.padding_size)
+    elif args.model_type == 'lstm_net':
+        model = LSTM(args.vocab_size, args.padding_size)
+    elif args.model_type == 'gru_net':
+        model = GRU(args.vocab_size, args.padding_size)
+    elif args.model_type == 'bigru_net':
+        model = BiGRU(args.vocab_size, args.batch_size, args.padding_size)
+    else:
+        raise ValueError("Unknown model type!")
+    return model
+
+
 def train():
     fluid.enable_dygraph(device)
     processor = SentaProcessor(
@@ -65,24 +85,14 @@ def train():
         phase='dev',
         epoch=args.epoch,
         shuffle=False)
-    if args.model_type == 'cnn_net':
-        model = CNN(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'bow_net':
-        model = BOW(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'gru_net':
-        model = GRU(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'bigru_net':
-        model = BiGRU(args.vocab_size, args.batch_size, args.padding_size)
-    optimizer = fluid.optimizer.Adagrad(
-        learning_rate=args.lr, parameter_list=model.parameters())
     inputs = [Input([None, None], 'int64', name='doc')]
     labels = [Input([None, 1], 'int64', name='label')]
+    model = create_model()
+    optimizer = fluid.optimizer.Adagrad(
+        learning_rate=args.lr, parameter_list=model.parameters())
     model.prepare(
         optimizer,
         CrossEntropy(),
@@ -113,19 +123,8 @@ def infer():
         phase='infer',
         epoch=1,
         shuffle=False)
-    if args.model_type == 'cnn_net':
-        model_infer = CNN(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'bow_net':
-        model_infer = BOW(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'gru_net':
-        model_infer = GRU(args.vocab_size, args.batch_size, args.padding_size)
-    elif args.model_type == 'bigru_net':
-        model_infer = BiGRU(args.vocab_size, args.batch_size,
-                            args.padding_size)
+    model_infer = create_model()
     print('Do inferring ...... ')
     inputs = [Input([None, None], 'int64', name='doc')]
     model_infer.prepare(
hapi/text/emo_tect/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hapi.text.emo_tect.data_processor import EmoTectProcessor
hapi/text/emo_tect/data_processor.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.io import DataLoader

from hapi.text.emo_tect.data_reader import load_vocab
from hapi.text.emo_tect.data_reader import data_reader


class EmoTectProcessor(object):
    def __init__(self, data_dir, vocab_path, random_seed=None):
        self.data_dir = data_dir
        self.vocab = load_vocab(vocab_path)
        self.num_examples = {"train": -1, "dev": -1, "test": -1, "infer": -1}
        np.random.seed(random_seed)

    def get_train_examples(self, data_dir, epoch, shuffle, batch_size, places,
                           padding_size):
        train_reader = data_reader((self.data_dir + "/train.tsv"), self.vocab,
                                   self.num_examples, "train", epoch,
                                   padding_size, shuffle)
        loader = DataLoader.from_generator(capacity=50, return_list=True)
        loader.set_sample_generator(
            train_reader,
            batch_size=batch_size,
            drop_last=False,
            places=places)
        return loader

    def get_dev_examples(self, data_dir, epoch, shuffle, batch_size, places,
                         padding_size):
        dev_reader = data_reader((self.data_dir + "/dev.tsv"), self.vocab,
                                 self.num_examples, "dev", epoch,
                                 padding_size, shuffle)
        loader = DataLoader.from_generator(capacity=50, return_list=True)
        loader.set_sample_generator(
            dev_reader, batch_size=batch_size, drop_last=False, places=places)
        return loader

    def get_test_examples(self, data_dir, epoch, batch_size, places,
                          padding_size):
        test_reader = data_reader((self.data_dir + "/test.tsv"), self.vocab,
                                  self.num_examples, "test", epoch,
                                  padding_size)
        loader = DataLoader.from_generator(capacity=50, return_list=True)
        loader.set_sample_generator(
            test_reader, batch_size=batch_size, drop_last=False, places=places)
        return loader

    def get_infer_examples(self, data_dir, epoch, batch_size, places,
                           padding_size):
        infer_reader = data_reader((self.data_dir + "/infer.tsv"), self.vocab,
                                   self.num_examples, "infer", epoch,
                                   padding_size)
        loader = DataLoader.from_generator(capacity=50, return_list=True)
        loader.set_sample_generator(
            infer_reader,
            batch_size=batch_size,
            drop_last=False,
            places=places)
        return loader

    def get_labels(self):
        return ["0", "1", "2"]

    def get_num_examples(self, phase):
        if phase not in ['train', 'dev', 'test', 'infer']:
            raise ValueError(
                "Unknown phase, which should be in "
                "['train', 'dev', 'test', 'infer'].")
        return self.num_examples[phase]

    def get_train_progress(self):
        return self.current_train_example, self.current_train_epoch

    def data_generator(self,
                       padding_size,
                       batch_size,
                       places,
                       phase='train',
                       epoch=1,
                       shuffle=True):
        if phase == "train":
            return self.get_train_examples(self.data_dir, epoch, shuffle,
                                           batch_size, places, padding_size)
        elif phase == "dev":
            return self.get_dev_examples(self.data_dir, epoch, shuffle,
                                         batch_size, places, padding_size)
        elif phase == "test":
            return self.get_test_examples(self.data_dir, epoch, batch_size,
                                          places, padding_size)
        elif phase == "infer":
            return self.get_infer_examples(self.data_dir, epoch, batch_size,
                                           places, padding_size)
        else:
            raise ValueError(
                "Unknown phase, which should be in "
                "['train', 'dev', 'test', 'infer'].")
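As used in run_classifier.py above, data_generator builds one paddle.io.DataLoader per phase; a minimal sketch of a typical call (place obtained from hapi.model.set_device, paths and sizes taken from config.yaml):

from hapi.model import set_device
from hapi.text.emo_tect import EmoTectProcessor

place = set_device('cpu')
processor = EmoTectProcessor(data_dir='./data', vocab_path='./data/vocab.txt')
train_loader = processor.data_generator(
    padding_size=20, batch_size=24, places=place, phase='train', epoch=1)
for doc, label in train_loader:
    ...  # each batch yields padded id tensors and label tensors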
hapi/text/emo_tect/data_reader.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import os
import sys
import six
import random

import paddle
import paddle.fluid as fluid
import numpy as np


def word2id(word_dict, query):
    """
    Convert a word sequence into an id list
    """
    unk_id = len(word_dict)
    wids = [
        word_dict[w] if w in word_dict else unk_id
        for w in query.strip().split(" ")
    ]
    return wids


def pad_wid(wids, max_seq_len=128, pad_id=0):
    """
    Pad data to max_seq_len
    """
    seq_len = len(wids)
    if seq_len < max_seq_len:
        for i in range(max_seq_len - seq_len):
            wids.append(pad_id)
    else:
        wids = wids[:max_seq_len]
    return wids


def data_reader(file_path, word_dict, num_examples, phase, epoch, max_seq_len,
                shuffle=False):
    """
    Data reader, which converts a word sequence into an id list
    """
    unk_id = len(word_dict)
    all_data = []
    with io.open(file_path, "r", encoding='utf8') as fin:
        for line in fin:
            if line.startswith("label"):
                continue
            if phase == "infer":
                cols = line.strip().split("\t")
                query = cols[-1] if len(cols) != 1 else cols[0]
                wids = word2id(word_dict, query)
                wids = pad_wid(wids, max_seq_len, unk_id)
                all_data.append(wids)
            else:
                cols = line.strip().split("\t")
                if len(cols) != 2:
                    sys.stderr.write("[NOTICE] Error Format Line!")
                    continue
                label = [int(cols[0])]
                query = cols[1].strip()
                wids = word2id(word_dict, query)
                wids = pad_wid(wids, max_seq_len, unk_id)
                all_data.append((wids, label))
    num_examples[phase] = len(all_data)

    if phase == "infer":

        def reader():
            """
            Infer reader function
            """
            for wids in all_data:
                yield wids

        return reader

    def reader():
        """
        Reader function
        """
        for idx in range(epoch):
            if phase == "train" and shuffle:
                random.shuffle(all_data)
            for wids, label in all_data:
                yield wids, label

    return reader


def load_vocab(file_path):
    """
    Load the given vocabulary
    """
    vocab = {}
    with io.open(file_path, 'r', encoding='utf8') as fin:
        wid = 0
        for line in fin:
            if line.strip() not in vocab:
                vocab[line.strip()] = wid
                wid += 1
    vocab["<unk>"] = len(vocab)
    return vocab


def query2ids(vocab_path, query):
    """
    Convert a query to an id list according to the given vocab
    """
    vocab = load_vocab(vocab_path)
    wids = word2id(vocab, query)
    return wids
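A small illustration of the two helpers above (hypothetical two-word vocab; note that data_reader reuses unk_id as the pad id):

vocab = {'happy': 0, 'sad': 1}
unk_id = len(vocab)                                       # 2
print(word2id(vocab, 'happy sad unknown'))                # [0, 1, 2]
print(pad_wid([0, 1, 2], max_seq_len=5, pad_id=unk_id))   # [0, 1, 2, 2, 2]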
hapi/text/senta/__init__.py (modified)
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from hapi.text.senta.data_processer import SentaProcessor
+from hapi.text.senta.data_processor import SentaProcessor
hapi/text/senta/data_processer.py → hapi/text/senta/data_processor.py (file moved)