From ebf3b702fc06707181cbba616d5db84a8e09e0b6 Mon Sep 17 00:00:00 2001
From: jiangzhonglian
Date: Mon, 29 Jul 2019 18:43:39 +0800
Subject: [PATCH] Update the Chinese version of the ChatBot
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/Chinese_ChatBot/run_demo.py  |   8 +-
 src/Chinese_ChatBot/run_train.py | 206 +++++++++----------------------
 src/Chinese_ChatBot/u_tools.py   |   5 +-
 3 files changed, 70 insertions(+), 149 deletions(-)

diff --git a/src/Chinese_ChatBot/run_demo.py b/src/Chinese_ChatBot/run_demo.py
index 066ff0c..2589040 100644
--- a/src/Chinese_ChatBot/run_demo.py
+++ b/src/Chinese_ChatBot/run_demo.py
@@ -52,7 +52,8 @@ def evaluateInput(encoder, decoder, searcher, voc):
             output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
             # Format and print response sentence
             output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
-            print('Bot:', ' '.join(output_words))
+            # print('Bot:', ' '.join(output_words))
+            print('Bot:', ''.join(output_words))
         except KeyError:
             print("Error: Encountered unknown word.")
 
@@ -62,7 +63,7 @@ if __name__ == "__main__":
     global device, corpus_name
     USE_CUDA = torch.cuda.is_available()
     device = torch.device("cuda" if USE_CUDA else "cpu")
-    corpus_name = "cornell_movie-dialogs_corpus"
+    corpus_name = "Chinese_ChatBot"
 
 
     # Configure models
@@ -76,8 +77,9 @@
     cp_start_iteration = 0
     learning_rate = 0.0001
     decoder_learning_ratio = 5.0
+    n_iteration = 5000
 
-    loadFilename = "data/save_copy/cb_model/%s/2-2_500/6000_checkpoint.tar" % corpus_name
+    loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration)
     if os.path.exists(loadFilename):
         voc = Voc(corpus_name)
         cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio)
diff --git a/src/Chinese_ChatBot/run_train.py b/src/Chinese_ChatBot/run_train.py
index ff15c0c..1e0f7e9 100644
--- a/src/Chinese_ChatBot/run_train.py
+++ b/src/Chinese_ChatBot/run_train.py
@@ -19,7 +19,7 @@ from u_class import Voc
 PAD_token = 0  # Used for padding short sentences
 SOS_token = 1  # Start-of-sentence token
 EOS_token = 2  # End-of-sentence token
-MAX_LENGTH = 10  # Maximum sentence length to consider
+MAX_LENGTH = 50  # Maximum sentence length to consider
 MIN_COUNT = 3    # Minimum word count threshold for trimming
 
 
@@ -29,84 +29,6 @@ def printLines(file, n=10):
     for line in lines[:n]:
         print(line)
 
-# Splits each line of the file into a dictionary of fields
-def loadLines(fileName, fields):
-    lines = {}
-    with open(fileName, 'r', encoding='iso-8859-1') as f:
-        for line in f:
-            values = line.split(" +++$+++ ")
-            # Extract fields
-            lineObj = {}
-            for i, field in enumerate(fields):
-                lineObj[field] = values[i]
-            lines[lineObj['lineID']] = lineObj
-    return lines
-
-# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
-def loadConversations(fileName, lines, fields):
-    conversations = []
-    with open(fileName, 'r', encoding='iso-8859-1') as f:
-        for line in f:
-            values = line.split(" +++$+++ ")
-            # Extract fields
-            convObj = {}
-            for i, field in enumerate(fields):
-                convObj[field] = values[i]
-            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
-            lineIds = eval(convObj["utteranceIDs"])
-            # Reassemble lines
convObj["lines"] = [] - for lineId in lineIds: - convObj["lines"].append(lines[lineId]) - conversations.append(convObj) - return conversations - -# Extracts pairs of sentences from conversations -def extractSentencePairs(conversations): - qa_pairs = [] - for conversation in conversations: - # Iterate over all the lines of the conversation - for i in range(len(conversation["lines"]) - 1): # We ignore the last line (no answer for it) - inputLine = conversation["lines"][i]["text"].strip() - targetLine = conversation["lines"][i+1]["text"].strip() - # Filter wrong samples (if one of the lists is empty) - if inputLine and targetLine: - qa_pairs.append([inputLine, targetLine]) - return qa_pairs - - - -def get_datafile(datafile): - # Define path to new file - delimiter = '\t' - # Unescape the delimiter - delimiter = str(codecs.decode(delimiter, "unicode_escape")) - - # Initialize lines dict, conversations list, and field ids - lines = {} - conversations = [] - MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"] - MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"] - - # Load lines and process conversations - print("\nProcessing corpus...") - lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS) - # print(">>> ", lines) - print("\nLoading conversations...") - conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"), lines, MOVIE_CONVERSATIONS_FIELDS) - # print(">>> ", conversations) - - # Write new csv file - print("\nWriting newly formatted file...") - with open(datafile, 'w', encoding='utf-8') as outputfile: - writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n') - for pair in extractSentencePairs(conversations): - writer.writerow(pair) - - # Print a sample of lines - print("\nSample lines from file:") - printLines(datafile) - def indexesFromSentence(voc, sentence): return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] @@ -148,18 +70,19 @@ def outputVar(l, voc): def readVocs(datafile, corpus_name): print("Reading lines...") # Read the file and split into lines - lines = open(datafile, encoding='utf-8').\ - read().strip().split('\n') + lines = open(datafile, encoding='utf-8').read().strip().split('\n') # Split every line into pairs and normalize - pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] + pairs = [[normalizeString(s) for s in l.split(' | ')] for l in lines] voc = Voc(corpus_name) return voc, pairs + # 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值,则返回True def filterPair(p): # Input sequences need to preserve the last word for EOS token return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH + # 过滤满足条件的 pairs 对话 def filterPairs(pairs): return [pair for pair in pairs if filterPair(pair)] @@ -345,65 +268,14 @@ def trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, enc }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint'))) -# if __name__ == "__main__": - -# global teacher_forcing_ratio, hidden_size -# # Configure models -# model_name = 'cb_model' -# attn_model = 'dot' -# #attn_model = 'general' -# #attn_model = 'concat' -# hidden_size = 500 -# encoder_n_layers = 2 -# decoder_n_layers = 2 -# dropout = 0.1 -# cp_start_iteration = 0 -# learning_rate = 0.0001 -# decoder_learning_ratio = 5.0 -# teacher_forcing_ratio = 1.0 -# clip = 50.0 -# print_every = 1 -# batch_size = 64 -# save_every = 1000 -# n_iteration = 7000 - -# loadFilename = 
"data/save/cb_model/%s/2-2_500/6000_checkpoint.tar" % corpus_name -# if os.path.exists(loadFilename): -# voc = Voc(corpus_name) -# cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio) - -# # Use appropriate device -# encoder = encoder.to(device) -# decoder = decoder.to(device) -# for state in encoder_optimizer.state.values(): -# for k, v in state.items(): -# if isinstance(v, torch.Tensor): -# state[k] = v.cuda() - -# for state in decoder_optimizer.state.values(): -# for k, v in state.items(): -# if isinstance(v, torch.Tensor): -# state[k] = v.cuda() - -# # Ensure dropout layers are in train mode -# encoder.train() -# decoder.train() - -# print("Starting Training!") -# trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name) - - -def train(p1=0, p2=0): +def TrainModel(): global device, corpus_name USE_CUDA = torch.cuda.is_available() device = torch.device("cuda" if USE_CUDA else "cpu") corpus_name = "Chinese_ChatBot" corpus = os.path.join("data", corpus_name) - # printLines(os.path.join(corpus, "movie_lines.txt")) - - datafile = os.path.join(corpus, "formatted_movie_lines.txt") - get_datafile(datafile) + datafile = os.path.join(corpus, "formatted_data.csv") # Load/Assemble voc and pairs save_dir = os.path.join("data", "save") @@ -416,16 +288,60 @@ def train(p1=0, p2=0): # Trim voc and pairs pairs = trimRareWords(voc, pairs, MIN_COUNT) - # # Example for validation - # small_batch_size = 5 - # batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)]) - # input_variable, lengths, target_variable, mask, max_target_len = batches - # print("input_variable:", input_variable) - # print("lengths:", lengths) - # print("target_variable:", target_variable) - # print("mask:", mask) - # print("max_target_len:", max_target_len) - + # Example for validation + small_batch_size = 5 + batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)]) + input_variable, lengths, target_variable, mask, max_target_len = batches + print("input_variable:", input_variable) + print("lengths:", lengths) + print("target_variable:", target_variable) + print("mask:", mask) + print("max_target_len:", max_target_len) + + global teacher_forcing_ratio, hidden_size + # Configure models + model_name = 'cb_model' + attn_model = 'dot' + #attn_model = 'general' + #attn_model = 'concat' + hidden_size = 500 + encoder_n_layers = 2 + decoder_n_layers = 2 + dropout = 0.1 + cp_start_iteration = 0 + learning_rate = 0.0001 + decoder_learning_ratio = 5.0 + teacher_forcing_ratio = 1.0 + clip = 50.0 + print_every = 1 + batch_size = 64 + save_every = 1000 + n_iteration = 5000 + + loadFilename = "data/save/cb_model/%s/2-2_500/%s_checkpoint.tar" % (corpus_name, n_iteration) + if os.path.exists(loadFilename): + voc = Voc(corpus_name) + cp_start_iteration, voc, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding = load_model(loadFilename, voc, cp_start_iteration, attn_model, hidden_size, encoder_n_layers, decoder_n_layers, dropout, learning_rate, decoder_learning_ratio) + + # Use appropriate device + encoder = encoder.to(device) + decoder = decoder.to(device) + for state in encoder_optimizer.state.values(): + 
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.cuda()
+
+    for state in decoder_optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.cuda()
+
+    # Ensure dropout layers are in train mode
+    encoder.train()
+    decoder.train()
+
+    print("Starting Training!")
+    trainIters(model_name, cp_start_iteration, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name)
 
 
 if __name__ == "__main__":
diff --git a/src/Chinese_ChatBot/u_tools.py b/src/Chinese_ChatBot/u_tools.py
index 462072a..c23d5eb 100644
--- a/src/Chinese_ChatBot/u_tools.py
+++ b/src/Chinese_ChatBot/u_tools.py
@@ -22,8 +22,11 @@ def unicodeToAscii(s):
 def normalizeString(s):
     s = unicodeToAscii(s.lower().strip())
     s = re.sub(r"([.!?])", r" \1", s)
-    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
+    # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
+    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)
     s = re.sub(r"\s+", r" ", s).strip()
+    # '咋死 ? ? ?红烧还是爆炒dddd' > '咋 死 ? ? ? 红 烧 还 是 爆 炒 d d d d'
+    s = " ".join(list(s))
     return s
 
--
GitLab
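The u_tools.py and run_demo.py hunks above work together: normalizeString now keeps CJK characters and splits every sentence into space-separated single characters, readVocs splits each corpus line on ' | ', and the demo re-joins the generated tokens without spaces before printing. A minimal standalone sketch of that round trip follows; it assumes unicodeToAscii matches the NFD-based helper from the original PyTorch chatbot tutorial, and the sample "question | answer" line is made up for illustration only.

# -*- coding: utf-8 -*-
import re
import unicodedata

def unicodeToAscii(s):
    # Assumed to match the tutorial helper: drop combining marks after NFD normalization
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    # Keep ASCII letters, . ! ? and CJK characters (U+4E00-U+9FA5); everything else becomes a space
    s = re.sub(r"[^a-zA-Z.!?\u4E00-\u9FA5]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    # Character-level split so the vocabulary is built from single Chinese characters
    s = " ".join(list(s))
    return s

if __name__ == "__main__":
    # Hypothetical corpus line in the ' | '-delimited format that readVocs now expects
    line = "今天吃什么 | 红烧还是爆炒"
    pair = [normalizeString(s) for s in line.split(' | ')]
    print(pair)                            # ['今 天 吃 什 么', '红 烧 还 是 爆 炒']

    # run_demo.py reverses the character split when printing the bot reply
    reply_tokens = ['红', '烧']
    print('Bot:', ''.join(reply_tokens))   # Bot: 红烧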