OpenDocCN / nlp-pytorch-zh
Commit 0303e3d7
Authored on Jul 30, 2019 by 片刻小哥哥
Update QQ message processing to form dialogues (更新QQ信息处理,形成对话)
Parent: ebf3b702
Showing 4 changed files with 203 additions and 41 deletions (+203 -41)

src/Chinese_ChatBot/QQ_ETL.py     +159  -0
src/Chinese_ChatBot/run_demo.py   +7    -6
src/Chinese_ChatBot/run_train.py  +33   -31
src/Chinese_ChatBot/u_tools.py    +4    -4
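The new src/Chinese_ChatBot/QQ_ETL.py (shown in full below) drives a two-stage clean-up: format_chat_data rewrites the raw QQ export so that each message ends with ' | [record header]', and format_2 then strips the headers and pairs consecutive messages into 'input | target' training lines. A rough sketch of the two intermediate files (the first header string is taken from the example in the code's comments; the second header and the pairing itself are illustrative, not real data):

format_1.csv (one message per line, with its record header appended):
    选修的 | [2016-06-24 15:42:52 张某(40**21)]
    就怕这次 | [2016-06-24 15:43:07 李某(52**33)]

format_2.csv (consecutive messages paired as input | target):
    选修的 | 就怕这次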
src/Chinese_ChatBot/QQ_ETL.py (new file, mode 0 → 100644)
#!/usr/bin/python
# coding: utf-8
import re
import codecs
import pandas as pd

script_name = "QQ聊天记录整理"

# 1. Use a regex to extract two arrays: all record headers and all record bodies.
#    One header corresponds to one body, so the two arrays should have the same length.
# 2. Process the record bodies:
# 2.1 Windows line breaks are '\r\n'; a bare '\n' does not render as a line break there.
#     Some breaks in logs exported from the mobile client are '\n' and must be replaced.
# 2.2 The record header is appended after the last line of each record. To keep the headers
#     aligned, we must compute how many spaces to pad in front of each header: Windows Notepad
#     renders a Chinese character as two columns and an English character as one, while Python's
#     len() counts a Chinese character as 1, so the padding is computed so that each header
#     starts at a multiple of 100 columns.
# 3. Mind the encoding conversion when reading and writing files.


def length_w(text):
    '''Compute the actual display width of a string in Windows Notepad (Chinese counts as 2 columns).'''
    length = len(text)  # character count (Chinese 1, English 1)
    utf8_length = len(text.encode('utf-8'))  # UTF-8 byte count (Chinese 3, English 1)
    length = int((utf8_length - length) / 2) + length  # width with Chinese counted as 2 and English as 1
    # This is still not exact: some special characters make the computed width differ from the
    # displayed width. The line-wrapping code below therefore uses a different approach, so that
    # special characters can never push a line beyond the limit; it is still approximate, but it
    # never exceeds the limit. For example:
    # '°' takes 2 columns in Notepad, and its UTF-8 encoding b'\xc2\xb0' is 2 bytes.
    # '�' takes 1 column in Notepad, and its UTF-8 encoding b'\xef\xbf\xbd' is 3 bytes.
    # '' takes 2 columns in Notepad, and its UTF-8 encoding b'\x01' is 1 byte (the character cannot be displayed).
    # The tab character '\t' is best replaced with four spaces up front, to avoid its automatic indentation.
    return length


def chinese_linefeed(text, limit):
    '''Wrap mixed Chinese/English text: limit the display width of each line, breaking lines that exceed it.'''
    text_format = ''  # result accumulator
    text = text.replace('\t', '    ')
    text = text.replace('\r\n', '\n')
    text_arr = text.split('\n')  # split the text into lines
    for line in text_arr:  # process line by line
        text_format += '\r\n'
        num = 0  # running width of the current output line
        for i in line:  # accumulate width from the first character of the line
            # Chinese characters count as 2 columns,
            # ASCII characters (English letters, symbols, etc.) count as 1,
            # everything else (some specials) counts as 2.
            if i >= u'\u4e00' and i <= u'\u9fa5':
                char_len = 2
            elif i >= u'\u001c' and i <= u'\u00ff':
                char_len = 1
            else:
                char_len = 2
            # Within the limit: append the character and add its width.
            # Beyond the limit: start a new line first, then reset the counter.
            if num + char_len <= limit:
                text_format += i
                num += char_len
            else:
                text_format += '\r\n' + i
                num = char_len
    return text_format.strip()


def format_chat_data(infile, outfile):
    """Reformat a QQ chat log exported as text from the mobile client."""
    # Read the input file
    fp = codecs.open(infile, 'r', 'utf-8')
    txt = fp.read()
    fp.close()

    re_pat = r'20[\d-]{8}\s[\d:]{7,8}\s+[^\n]+(?:\d{5,11}|@\w+\.[comnet]{2,3})\)'  # regex matching a record header
    log_title_arr = re.findall(re_pat, txt)   # header array, e.g. ['2016-06-24 15:42:52 张某(40**21)',…]
    log_content_arr = re.split(re_pat, txt)   # body array, e.g. ['\n', '\n选修的\n\n', '\n就怕这次…]
    log_content_arr.pop(0)                    # drop the first element (left over from the split)

    # array lengths
    l1 = len(log_title_arr)
    l2 = len(log_content_arr)
    print('记录头数: %d\n记录内容: %d' % (l1, l2))

    if l1 == l2:
        log_format = ''  # the reformatted log
        for i in range(0, l1):
            title = log_title_arr[i]                  # record header
            content = log_content_arr[i].strip()      # strip leading/trailing whitespace from the body
            content = content.replace('\r\n', '\n')   # normalize line breaks to '\n' ...
            content = content.replace('\n', '\r\n')   # ... then turn every '\n' into '\r\n'
            content = chinese_linefeed(content, 100)  # wrap lines that are too long
            lastline = content.split('\r\n')[-1]      # last line of the body
            length = length_w(lastline)               # its display width
            # space = (100-(length%100))*' ' if length%100!=0 else ''  # pad so the header starts at a multiple of 100 columns; no padding if the remainder is 0
            space = ' | '  # use ' | ' as the separator before the header instead of space padding
            log_format += content + space + '[' + title + ']\r\n'  # stitch the record back together
        # write the result
        fp = codecs.open(outfile, 'w', 'utf-8')
        fp.write(log_format)
        fp.close()
        print("整理完毕~^_^~")
    else:
        print('记录头和记录内容条数不匹配,请修正代码')


def split_line(line):
    l = re.sub(r"[\[\]]+", "", str(line).strip()).split(" | ")
    if len(l) == 2:
        content = l[0]
        names = l[1].split(" ")
        if len(names) == 3:  # header splits into date, time and sender id
            c_time = names[0] + names[1]
            c_id = names[2]
            print([content, c_time, c_id])
            # return "%s | %s | %s" % (content, c_time, c_id)
            return content
        return content
    return ""


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversation):
    qa_pairs = []
    for i in range(len(conversation) - 1):  # We ignore the last line (no answer for it)
        inputLine = conversation[i].strip()
        targetLine = conversation[i + 1].strip()
        # Filter wrong samples (if one of the lists is empty)
        if inputLine and targetLine:
            qa_pairs.append("%s | %s" % (inputLine, targetLine))
    return qa_pairs


def format_2(infile, outfile):
    df = pd.read_csv(infile, sep='\00001', header=None, names=["txt"])
    # print(df["txt"].head(5))
    df["content"] = df["txt"].apply(lambda line: split_line(line))
    # df.query("content!=''")["content"].to_csv(outfile, sep="\t", header=False, index=False)
    lines = df.query("content!=''")["content"].tolist()
    # print(lines)
    chats = extractSentencePairs(lines)
    df_chats = pd.DataFrame(chats, columns=['lines'])
    df_chats.to_csv(outfile, sep="\t", header=False, index=False)
    print(">>> 数据合并成功: %s" % outfile)


if __name__ == "__main__":
    infile = r'data/QQChat/ML_ApacheCN.csv'
    outfile_1 = r'data/QQChat/format_1.csv'
    outfile_2 = r'data/QQChat/format_2.csv'
    # format_chat_data(infile, outfile_1)
    format_2(outfile_1, outfile_2)
\ No newline at end of file
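length_w above infers display width from the gap between the character count and the UTF-8 byte count (a CJK character is 3 bytes but should occupy 2 columns). A quick worked check of that arithmetic, using an illustrative input string:

# Worked check of length_w's width arithmetic (illustrative input, not from the data).
text = "ab测试"
length = len(text)                               # 4 characters
utf8_length = len(text.encode('utf-8'))          # 1 + 1 + 3 + 3 = 8 bytes
print(int((utf8_length - length) / 2) + length)  # (8 - 4) / 2 + 4 = 6 display columns
# i.e. 'a' and 'b' take 1 column each, '测' and '试' take 2 columns each in Notepad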
src/Chinese_ChatBot/run_demo.py

@@ -13,7 +13,7 @@ from u_class import Voc, GreedySearchDecoder
 PAD_token = 0  # Used for padding short sentences
 SOS_token = 1  # Start-of-sentence token
 EOS_token = 2  # End-of-sentence token
-MAX_LENGTH = 10  # Maximum sentence length to consider
+MAX_LENGTH = 50  # Maximum sentence length to consider

 def indexesFromSentence(voc, sentence):

@@ -77,12 +77,15 @@ if __name__ == "__main__":
     cp_start_iteration = 0
     learning_rate = 0.0001
     decoder_learning_ratio = 5.0
-    n_iteration = 5000
+    n_iteration = 8000
+    voc = Voc(corpus_name)

     loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration)
     if os.path.exists(loadFilename):
-        voc = Voc(corpus_name)
-        cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
+        checkpoint = torch.load(loadFilename)
+        voc.__dict__ = checkpoint['voc_dict']
+        cp_start_iteration, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)

     # Use appropriate device
     encoder = encoder.to(device)

@@ -96,5 +99,3 @@ if __name__ == "__main__":
     # Begin chatting (uncomment and run the following line to begin)
     evaluateInput(encoder, decoder, searcher, voc)
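The change above moves vocabulary restoration out of load_model: the caller builds an empty Voc, loads the checkpoint, and overwrites voc.__dict__ with the saved 'voc_dict'. A minimal sketch of why that round-trips, using a toy stand-in for the project's Voc class (class body and file name here are illustrative, not from this repo):

import torch

class Voc:
    '''Toy stand-in for the project's Voc class (illustrative only).'''
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {0: "PAD", 1: "SOS", 2: "EOS"}
        self.num_words = 3

voc = Voc("Chinese_ChatBot")
voc.word2index = {"你好": 3}
voc.num_words = 4

# Saving: store the vocabulary's attribute dict under 'voc_dict'
# (a real checkpoint would also hold model and optimizer state dicts).
torch.save({"voc_dict": voc.__dict__}, "demo_checkpoint.tar")

# Loading: rebuild an empty Voc, then restore every attribute in one assignment.
restored = Voc("Chinese_ChatBot")
checkpoint = torch.load("demo_checkpoint.tar")
restored.__dict__ = checkpoint["voc_dict"]
print(restored.num_words, restored.word2index)  # 4 {'你好': 3}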
src/Chinese_ChatBot/run_train.py

@@ -67,14 +67,13 @@ def outputVar(l, voc):
 # Initialize the Voc object and store the formatted dialogue pairs in a list
-def readVocs(datafile, corpus_name):
+def readVocs(datafile):
     print("Reading lines...")
     # Read the file and split into lines
     lines = open(datafile, encoding='utf-8').read().strip().split('\n')
     # Split every line into pairs and normalize
     pairs = [[normalizeString(s) for s in l.split(' | ')] for l in lines]
-    voc = Voc(corpus_name)
-    return voc, pairs
+    return pairs

 # Return True only if both sentences in the pair 'p' are below the MAX_LENGTH threshold

@@ -88,9 +87,9 @@ def filterPairs(pairs):
     return [pair for pair in pairs if filterPair(pair)]

 # Using the functions defined above, return a populated voc object and a list of pairs
-def loadPrepareData(corpus, corpus_name, datafile, save_dir):
+def loadPrepareData(corpus, corpus_name, datafile, voc, save_dir):
     print("Start preparing training data ...")
-    voc, pairs = readVocs(datafile, corpus_name)
+    pairs = readVocs(datafile)
     print("Read {!s} sentence pairs".format(len(pairs)))
     pairs = filterPairs(pairs)
     print("Trimmed to {!s} sentence pairs".format(len(pairs)))

@@ -275,29 +274,9 @@ def TrainModel():
     corpus_name = "Chinese_ChatBot"
     corpus = os.path.join("data", corpus_name)
-    datafile = os.path.join(corpus, "formatted_data.csv")
-    # Load/Assemble voc and pairs
+    datafile = os.path.join(corpus, "format_data.csv")
     save_dir = os.path.join("data", "save")
-    voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
-    # Print some pairs to validate
-    print("\npairs:")
-    for pair in pairs[:10]:
-        print(pair)
-    # Trim voc and pairs
-    pairs = trimRareWords(voc, pairs, MIN_COUNT)
-    # Example for validation
-    small_batch_size = 5
-    batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
-    input_variable, lengths, target_variable, mask, max_target_len = batches
-    print("input_variable:", input_variable)
-    print("lengths:", lengths)
-    print("target_variable:", target_variable)
-    print("mask:", mask)
-    print("max_target_len:", max_target_len)

     global teacher_forcing_ratio, hidden_size
     # Configure models
     model_name = 'cb_model'

@@ -316,12 +295,35 @@ def TrainModel():
     print_every = 1
     batch_size = 64
     save_every = 1000
-    n_iteration = 5000
+    n_iteration = 8000

-    loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration)
+    voc = Voc(corpus_name)
+    loadFilename = "data/save/cb_model/%s/2-2_500/5000_checkpoint.tar" % (corpus_name)
     if os.path.exists(loadFilename):
-        voc = Voc(corpus_name)
-        cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
+        checkpoint = torch.load(loadFilename)
+        voc.__dict__ = checkpoint['voc_dict']
+
+    # Load/Assemble voc and pairs
+    voc, pairs = loadPrepareData(corpus, corpus_name, datafile, voc, save_dir)
+    # Print some pairs to validate
+    print("\npairs:")
+    for pair in pairs[:10]:
+        print(pair)
+    # Trim voc and pairs
+    pairs = trimRareWords(voc, pairs, MIN_COUNT)
+
+    # # Example for validation
+    # small_batch_size = 5
+    # batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
+    # input_variable, lengths, target_variable, mask, max_target_len = batches
+    # print("input_variable:", input_variable)
+    # print("lengths:", lengths)
+    # print("target_variable:", target_variable)
+    # print("mask:", mask)
+    # print("max_target_len:", max_target_len)
+
+    cp_start_iteration, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)

     # Use appropriate device
     encoder = encoder.to(device)
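After this commit readVocs only parses the pairs file, and the Voc is created by the caller, so a vocabulary restored from a checkpoint can be reused instead of being rebuilt from scratch. The file it parses is the 'input | target' output of QQ_ETL.format_2; a small self-contained illustration of that parsing step (the sample lines are illustrative, and normalizeString is replaced by an identity stand-in):

def normalizeString(s):
    # identity stand-in for u_tools.normalizeString, just for this illustration
    return s

lines = [
    "选修的 | 就怕这次",       # messages borrowed from QQ_ETL.py's example comments
    "红烧 还是 爆炒 | 好 的",  # hypothetical second pair
]
# This is the comprehension readVocs now applies to every line of format_data.csv.
pairs = [[normalizeString(s) for s in l.split(' | ')] for l in lines]
print(pairs)  # [['选修的', '就怕这次'], ['红烧 还是 爆炒', '好 的']]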
src/Chinese_ChatBot/u_tools.py

@@ -21,9 +21,9 @@ def unicodeToAscii(s):
 # Lowercase, trim, and remove non-letter characters
 def normalizeString(s):
     s = unicodeToAscii(s.lower().strip())
-    s = re.sub(r"([.!?])", r" \1", s)
+    s = re.sub(r"([.!?])", r" \1", s)
     # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
-    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)
+    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)
     s = re.sub(r"\s+", r" ", s).strip()
     # '咋死 ? ? ?红烧还是爆炒dddd' > '咋 死 ? ? ? 红 烧 还 是 爆 炒 d d d d'
     s = " ".join(list(s))

@@ -43,7 +43,7 @@ def load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, e
     encoder_optimizer_sd = checkpoint['state_dict_en_opt']
     decoder_optimizer_sd = checkpoint['state_dict_de_opt']
     # loss = checkpoint['loss']
-    voc.__dict__ = checkpoint['voc_dict']
+    # voc.__dict__ = checkpoint['voc_dict']
     embedding_sd = checkpoint['embedding']

     print('Building encoder and decoder ...')

@@ -65,4 +65,4 @@ def load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, e
     decoder_optimizer.load_state_dict(decoder_optimizer_sd)

     print('Models built and ready to go!')
-    return cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding
+    return cp_start_iteration, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding
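For reference, normalizeString character-tokenizes mixed Chinese/English text so that the chatbot's vocabulary is built per character rather than per word. A minimal self-contained sketch of that behaviour (unicodeToAscii is omitted and the regex classes are shown in their pre-commit form, so treat it as an approximation rather than the exact function):

import re

def normalize_zh(s):
    '''Approximate sketch of u_tools.normalizeString: keep letters, basic punctuation
    and CJK characters, then split the result into space-separated single characters.'''
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)                    # pad sentence punctuation with a space
    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)   # drop everything else
    s = re.sub(r"\s+", r" ", s).strip()                  # collapse whitespace
    return " ".join(list(s))                             # one token per character

print(normalize_zh("红烧还是爆炒dddd"))  # -> 红 烧 还 是 爆 炒 d d d d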