Commit 7db0d3df authored by 片刻小哥哥

Add attention mechanism optimization results

Parent c78b1c02
@@ -5,6 +5,8 @@
# https://blog.csdn.net/alip39/article/details/95891321
# Reference code:
# https://blog.csdn.net/u012052268/article/details/90238282
# Attention:
# https://github.com/philipperemy/keras-attention-mechanism
import re
import os
import keras
@@ -16,7 +18,8 @@ import jieba
from sklearn.model_selection import train_test_split
from keras import Model
from keras.models import load_model
from keras.layers import Dropout, Dense, Flatten, Bidirectional, Embedding, GRU, Input
from keras.layers.normalization import BatchNormalization
from keras.layers import Dropout, Dense, Flatten, Bidirectional, Embedding, GRU, Input, multiply
from keras.preprocessing.sequence import pad_sequences
from keras.utils.np_utils import to_categorical
from keras.optimizers import Adam
@@ -117,12 +120,15 @@ class EmotionModel(object):
# ) # set the maximum sentence length
print("Start training the model.....")
# Build the model
sequence_input = Input(shape=(self.MAX_SEQUENCE_LENGTH,), dtype='int32') # returns a tensor of length 1000, i.e. the model input is batch_size*1000
embedded_sequences = embedding_layer(sequence_input) # returns batch_size*1000*100
x = Bidirectional(GRU(100, return_sequences=True))(embedded_sequences)
x = Dropout(0.6)(x)
# Add attention (essentially a randomly initialized weight vector is learned to re-weight the input; unlike a plain fully connected layer, its output is also multiplied element-wise with the input)
attention_probs = Dense(self.EMBEDDING_DIM, activation='softmax', name='attention_probs')(embedded_sequences)
attention_mul = multiply([embedded_sequences, attention_probs], name='attention_mul')
x = Bidirectional(GRU(self.EMBEDDING_DIM, return_sequences=True, dropout=0.5))(attention_mul)
x = Dropout(0.5)(x)
x = Flatten()(x)
# x = BatchNormalization()(x)
preds = Dense(self.pre_num, activation='softmax')(x)
self.model = Model(sequence_input, preds)
# configure the optimizer
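For readers following along outside the diff, below is a minimal, self-contained sketch of the attention pattern this commit introduces: a softmax Dense layer over the embeddings, an element-wise multiply, then a bidirectional GRU. The vocabulary size, class count, and optimizer settings are placeholder assumptions, not values taken from the repository.

```python
# Standalone sketch of the attention-weighted BiGRU classifier added in this commit.
# MAX_SEQUENCE_LENGTH, EMBEDDING_DIM, vocab_size and num_classes are assumed values.
from keras import Model
from keras.layers import Input, Embedding, Dense, Dropout, Flatten, Bidirectional, GRU, multiply
from keras.optimizers import Adam

MAX_SEQUENCE_LENGTH = 1000
EMBEDDING_DIM = 100
vocab_size = 20000    # assumption: size of the tokenizer vocabulary
num_classes = 2       # assumption: number of emotion labels

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded = Embedding(vocab_size, EMBEDDING_DIM,
                     input_length=MAX_SEQUENCE_LENGTH)(sequence_input)

# Attention: a Dense softmax produces per-dimension weights, which are then
# multiplied element-wise with the embeddings before the recurrent layer.
attention_probs = Dense(EMBEDDING_DIM, activation='softmax',
                        name='attention_probs')(embedded)
attention_mul = multiply([embedded, attention_probs], name='attention_mul')

x = Bidirectional(GRU(EMBEDDING_DIM, return_sequences=True, dropout=0.5))(attention_mul)
x = Dropout(0.5)(x)
x = Flatten()(x)
preds = Dense(num_classes, activation='softmax')(x)

model = Model(sequence_input, preds)
model.compile(optimizer=Adam(lr=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
```

Naming the attention layer ('attention_probs') also makes it easy to pull the learned weights out later for inspection, as the referenced keras-attention-mechanism repository demonstrates.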
@@ -176,7 +182,6 @@ class EmotionModel(object):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=10000)
return (x_train, y_train), (x_test, y_test)
def train(self):
'''Train the model'''
vocab_list, word_index, embeddings_matrix = load_embeding()
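Because the attention layer carries an explicit name, its weights can be inspected after training, in the spirit of the referenced keras-attention-mechanism repository. A hedged sketch, assuming a trained `model` built as in the example above and a padded, integer-encoded batch `x_batch` (both hypothetical names):

```python
import numpy as np
from keras import Model

# Build a sub-model that stops at the named attention layer and returns its
# softmax weights for a batch of padded, integer-encoded sequences.
attention_model = Model(inputs=model.input,
                        outputs=model.get_layer('attention_probs').output)
attention_weights = attention_model.predict(x_batch)  # shape: (batch, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM)

# Average over the embedding dimension to get a per-token importance score.
token_importance = attention_weights.mean(axis=-1)
print(np.argsort(token_importance[0])[::-1][:10])  # indices of the 10 most-attended tokens
```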