Commit ebf3b702 authored by 片刻小哥哥

Update the Chinese version of the ChatBot

Parent dee47b66
@@ -52,7 +52,8 @@ def evaluateInput(encoder, decoder, searcher, voc):
output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
# Format and print response sentence
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
# print('Bot:', ' '.join(output_words))
print('Bot:', ''.join(output_words))
except KeyError:
print("Error: Encountered unknown word.")
@@ -62,7 +63,7 @@ if __name__ == "__main__":
global device, corpus_name
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
corpus_name = "cornell_movie-dialogs_corpus"
corpus_name = "Chinese_ChatBot"
# Configure models
@@ -76,8 +77,9 @@ if __name__ == "__main__":
cp_start_iteration = 0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 5000
loadFilename = "data/save_copy/cb_model/%s/2-2_500/6000_checkpoint.tar" % corpus_name
loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration)
if os.path.exists(loadFilename):
voc = Voc(corpus_name)
cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
......
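The checkpoint path above is now derived from corpus_name and n_iteration instead of being hard-coded. load_model itself is not shown in this diff; as a hedged sketch, assuming the checkpoint is the plain dict that trainIters saves with torch.save (see the save call later in this commit), the .tar file could be inspected like this:

# Hedged sketch (assumes the checkpoint is a dict saved by torch.save;
# load_model's real signature and contents are not shown in this diff).
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loadFilename = "data/save/cb_model/Chinese_ChatBot/2-2_500/5000_checkpoint.tar"
checkpoint = torch.load(loadFilename, map_location=device)  # remap GPU tensors if needed
print(sorted(checkpoint.keys()))  # e.g. model/optimizer state_dicts and the iteration count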
@@ -19,7 +19,7 @@ from u_class import Voc
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token
EOS_token = 2 # End-of-sentence token
MAX_LENGTH = 10 # Maximum sentence length to consider
MAX_LENGTH = 50 # Maximum sentence length to consider
MIN_COUNT = 3 # Minimum word count threshold for trimming
@@ -29,84 +29,6 @@ def printLines(file, n=10):
for line in lines[:n]:
print(line)
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
lines = {}
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
lineObj = {}
for i, field in enumerate(fields):
lineObj[field] = values[i]
lines[lineObj['lineID']] = lineObj
return lines
# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
conversations = []
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
convObj = {}
for i, field in enumerate(fields):
convObj[field] = values[i]
# Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
lineIds = eval(convObj["utteranceIDs"])
# Reassemble lines
convObj["lines"] = []
for lineId in lineIds:
convObj["lines"].append(lines[lineId])
conversations.append(convObj)
return conversations
# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
qa_pairs = []
for conversation in conversations:
# Iterate over all the lines of the conversation
for i in range(len(conversation["lines"]) - 1): # We ignore the last line (no answer for it)
inputLine = conversation["lines"][i]["text"].strip()
targetLine = conversation["lines"][i+1]["text"].strip()
# Filter wrong samples (if one of the lists is empty)
if inputLine and targetLine:
qa_pairs.append([inputLine, targetLine])
return qa_pairs
def get_datafile(datafile):
# Define path to new file
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
# print(">>> ", lines)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"), lines, MOVIE_CONVERSATIONS_FIELDS)
# print(">>> ", conversations)
# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
for pair in extractSentencePairs(conversations):
writer.writerow(pair)
# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)
def indexesFromSentence(voc, sentence):
return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
@@ -148,18 +70,19 @@ def outputVar(l, voc):
def readVocs(datafile, corpus_name):
print("Reading lines...")
# Read the file and split into lines
lines = open(datafile, encoding='utf-8').\
read().strip().split('\n')
lines = open(datafile, encoding='utf-8').read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
pairs = [[normalizeString(s) for s in l.split(' | ')] for l in lines]
voc = Voc(corpus_name)
return voc, pairs
# Returns True iff both sentences in the pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
# Input sequences need to preserve the last word for EOS token
return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH
# Filter the conversation pairs that satisfy the length condition
def filterPairs(pairs):
return [pair for pair in pairs if filterPair(pair)]
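readVocs above now splits each line on ' | ' instead of a tab, so the formatted data file is expected to hold one question/answer pair per line, with the Chinese text already space-separated per character by normalizeString (changed later in this commit); filterPair then counts those space-separated tokens against MAX_LENGTH = 50. A minimal sketch with a made-up line of data:

# Minimal sketch (the data line is invented): splitting one line of the formatted
# corpus into a pair and applying the new length filter.
MAX_LENGTH = 50
line = "你 好 | 你 好 呀"
pair = [s.strip() for s in line.split(' | ')]             # ['你 好', '你 好 呀']
keep = all(len(s.split(' ')) < MAX_LENGTH for s in pair)  # one token per character here
print(pair, keep)                                         # -> ['你 好', '你 好 呀'] True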
@@ -345,65 +268,14 @@ def trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, enc
}, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
# if __name__ == "__main__":
# global teacher_forcing_ratio, hidden_size
# # Configure models
# model_name = 'cb_model'
# attn_model = 'dot'
# #attn_model = 'general'
# #attn_model = 'concat'
# hidden_size = 500
# encoder_n_layers = 2
# decoder_n_layers = 2
# dropout = 0.1
# cp_start_iteration = 0
# learning_rate = 0.0001
# decoder_learning_ratio = 5.0
# teacher_forcing_ratio = 1.0
# clip = 50.0
# print_every = 1
# batch_size = 64
# save_every = 1000
# n_iteration = 7000
# loadFilename = "data/save/cb_model/%s/2-2_500/6000_checkpoint.tar" % corpus_name
# if os.path.exists(loadFilename):
# voc = Voc(corpus_name)
# cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
# # Use appropriate device
# encoder = encoder.to(device)
# decoder = decoder.to(device)
# for state in encoder_optimizer.state.values():
# for k, v in state.items():
# if isinstance(v, torch.Tensor):
# state[k] = v.cuda()
# for state in decoder_optimizer.state.values():
# for k, v in state.items():
# if isinstance(v, torch.Tensor):
# state[k] = v.cuda()
# # Ensure dropout layers are in train mode
# encoder.train()
# decoder.train()
# print("Starting Training!")
# trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name)
def train(p1=0, p2=0):
def TrainModel():
global device, corpus_name
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
corpus_name = "Chinese_ChatBot"
corpus = os.path.join("data", corpus_name)
# printLines(os.path.join(corpus, "movie_lines.txt"))
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
get_datafile(datafile)
datafile = os.path.join(corpus, "formatted_data.csv")
# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
@@ -416,16 +288,60 @@ def train(p1=0, p2=0):
# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)
# # Example for validation
# small_batch_size = 5
# batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
# input_variable, lengths, target_variable, mask, max_target_len = batches
# print("input_variable:", input_variable)
# print("lengths:", lengths)
# print("target_variable:", target_variable)
# print("mask:", mask)
# print("max_target_len:", max_target_len)
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches
print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)
global teacher_forcing_ratio, hidden_size
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
cp_start_iteration = 0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
teacher_forcing_ratio = 1.0
clip = 50.0
print_every = 1
batch_size = 64
save_every = 1000
n_iteration = 5000
loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration)
if os.path.exists(loadFilename):
voc = Voc(corpus_name)
cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
for state in encoder_optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda()
for state in decoder_optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda()
# Ensure dropout layers are in train mode
encoder.train()
decoder.train()
print("Starting Training!")
trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name)
if __name__ == "__main__":
......
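One detail in TrainModel above worth calling out: after load_model restores the optimizers on the CPU, their internal state tensors (for Adam, the moment estimates) still have to be moved to the GPU by hand, which is what the two nested loops over encoder_optimizer.state and decoder_optimizer.state do. A hedged sketch of the same pattern factored into a helper (the helper name is mine, not the repo's):

# Hedged sketch (hypothetical helper, not part of the repo): the same
# optimizer-state-to-device pattern as the loops in TrainModel.
import torch

def optimizer_to(optimizer, device):
    # Move every tensor held in the optimizer's state (e.g. Adam moments) to `device`.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)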
@@ -22,8 +22,11 @@ def unicodeToAscii(s):
def normalizeString(s):
s = unicodeToAscii(s.lower().strip())
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
# s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)
s = re.sub(r"\s+", r" ", s).strip()
# e.g. '咋死 ? ? ?红烧还是爆炒dddd' -> '咋 死 ? ? ? 红 烧 还 是 爆 炒 d d d d'
s = " ".join(list(s))
return s
......
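The added \u4E00-\u9FA5 range keeps CJK ideographs when stripping characters, and the final " ".join(list(s)) turns the cleaned string into one space-separated token per character, which is what the character-level vocabulary expects. A hedged, self-contained re-implementation of the new behaviour (unicodeToAscii is omitted here, so accented Latin input may be handled slightly differently than in the repo):

# Hedged sketch: standalone version of the new normalizeString behaviour
# (without unicodeToAscii, which only affects accented Latin characters).
import re

def normalize_zh(s):
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)                    # pad sentence-final punctuation
    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)   # keep Latin letters, .!? and CJK ideographs
    s = re.sub(r"\s+", r" ", s).strip()
    return " ".join(list(s))                             # one token per character

print(normalize_zh("红烧还是爆炒dddd"))   # -> 红 烧 还 是 爆 炒 d d d d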