Commit aa3e072e authored by sneaxiy

refine dam model

Parent 5d07cee8
......@@ -14,8 +14,13 @@ class Net(object):
         self._emb_size = emb_size
         self._stack_num = stack_num
         self.word_emb_name = "shared_word_emb"
+        self.use_stack_op = True
+        self.use_mask_cache = True
+        self.use_sparse_embedding = True

     def create_network(self):
+        mask_cache = dict() if self.use_mask_cache else None
+
         turns_data = []
         for i in xrange(self._max_turn_num):
             turn = fluid.layers.data(
......@@ -28,19 +33,22 @@ class Net(object):
         for i in xrange(self._max_turn_num):
             turn_mask = fluid.layers.data(
                 name="turn_mask_%d" % i,
-                shape=[self._max_turn_len],
+                shape=[self._max_turn_len, 1],
                 dtype="float32")
             turns_mask.append(turn_mask)

         response = fluid.layers.data(
             name="response", shape=[self._max_turn_len, 1], dtype="int32")
         response_mask = fluid.layers.data(
-            name="response_mask", shape=[self._max_turn_len], dtype="float32")
+            name="response_mask",
+            shape=[self._max_turn_len, 1],
+            dtype="float32")
         label = fluid.layers.data(name="label", shape=[1], dtype="float32")

         response_emb = fluid.layers.embedding(
             input=response,
             size=[self._vocab_size + 1, self._emb_size],
+            is_sparse=self.use_sparse_embedding,
             param_attr=fluid.ParamAttr(
                 name=self.word_emb_name,
                 initializer=fluid.initializer.Normal(scale=0.1)))
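This hunk, together with the flags added above (`use_stack_op`, `use_mask_cache`, `use_sparse_embedding`), changes every mask input from shape `[max_turn_len]` to `[max_turn_len, 1]`. The explicit trailing dimension lets `dot_product_attention` form its 2-D attention mask with a single batched matmul and no per-call `unsqueeze`. A minimal numpy sketch of that shape algebra (all sizes are made-up example values):

```python
# Why the masks gain a trailing 1-dim: with q_mask of shape [batch, len_q, 1]
# and k_mask of shape [batch, len_k, 1], one batched matmul against the
# transposed k_mask yields the full [batch, len_q, len_k] mask directly.
import numpy as np

batch, len_q, len_k = 2, 5, 5
q_mask = np.ones((batch, len_q, 1), dtype="float32")
k_mask = np.ones((batch, len_k, 1), dtype="float32")
q_mask[0, 3:, 0] = 0  # pretend sequence 0 has only 3 valid query tokens
k_mask[0, 4:, 0] = 0  # and 4 valid key tokens

# Equivalent of fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
mask = np.matmul(q_mask, k_mask.transpose(0, 2, 1))
assert mask.shape == (batch, len_q, len_k)
assert mask[0, 2, 4] == 0 and mask[0, 2, 3] == 1  # padded keys are masked out
```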
......@@ -57,7 +65,8 @@ class Net(object):
                 value=Hr,
                 d_key=self._emb_size,
                 q_mask=response_mask,
-                k_mask=response_mask)
+                k_mask=response_mask,
+                mask_cache=mask_cache)
             Hr_stack.append(Hr)

         # context part
......@@ -66,6 +75,7 @@ class Net(object):
             Hu = fluid.layers.embedding(
                 input=turns_data[t],
                 size=[self._vocab_size + 1, self._emb_size],
+                is_sparse=self.use_sparse_embedding,
                 param_attr=fluid.ParamAttr(
                     name=self.word_emb_name,
                     initializer=fluid.initializer.Normal(scale=0.1)))
......@@ -80,7 +90,8 @@ class Net(object):
                     value=Hu,
                     d_key=self._emb_size,
                     q_mask=turns_mask[t],
-                    k_mask=turns_mask[t])
+                    k_mask=turns_mask[t],
+                    mask_cache=mask_cache)
                 Hu_stack.append(Hu)

             # cross attention
......@@ -94,7 +105,8 @@ class Net(object):
                     value=Hr_stack[index],
                     d_key=self._emb_size,
                     q_mask=turns_mask[t],
-                    k_mask=response_mask)
+                    k_mask=response_mask,
+                    mask_cache=mask_cache)
                 r_a_t = layers.block(
                     name="r_attend_t_" + str(index),
                     query=Hr_stack[index],
......@@ -102,7 +114,8 @@
                     value=Hu_stack[index],
                     d_key=self._emb_size,
                     q_mask=response_mask,
-                    k_mask=turns_mask[t])
+                    k_mask=turns_mask[t],
+                    mask_cache=mask_cache)
                 t_a_r_stack.append(t_a_r)
                 r_a_t_stack.append(r_a_t)
......@@ -110,25 +123,32 @@
             t_a_r_stack.extend(Hu_stack)
             r_a_t_stack.extend(Hr_stack)

-            for index in xrange(len(t_a_r_stack)):
-                t_a_r_stack[index] = fluid.layers.unsqueeze(
-                    input=t_a_r_stack[index], axes=[1])
-                r_a_t_stack[index] = fluid.layers.unsqueeze(
-                    input=r_a_t_stack[index], axes=[1])
+            if self.use_stack_op:
+                t_a_r = fluid.layers.stack(t_a_r_stack, axis=1)
+                r_a_t = fluid.layers.stack(r_a_t_stack, axis=1)
+            else:
+                for index in xrange(len(t_a_r_stack)):
+                    t_a_r_stack[index] = fluid.layers.unsqueeze(
+                        input=t_a_r_stack[index], axes=[1])
+                    r_a_t_stack[index] = fluid.layers.unsqueeze(
+                        input=r_a_t_stack[index], axes=[1])

-            t_a_r = fluid.layers.concat(input=t_a_r_stack, axis=1)
-            r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)
+                t_a_r = fluid.layers.concat(input=t_a_r_stack, axis=1)
+                r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)

             # sim shape: [batch_size, 2*(stack_num+2), max_turn_len, max_turn_len]
             sim = fluid.layers.matmul(x=t_a_r, y=r_a_t, transpose_y=True)
             sim = fluid.layers.scale(x=sim, scale=1 / np.sqrt(200.0))
             sim_turns.append(sim)

-        for index in xrange(len(sim_turns)):
-            sim_turns[index] = fluid.layers.unsqueeze(
-                input=sim_turns[index], axes=[2])
-        # sim shape: [batch_size, 2*(stack_num+2), max_turn_num, max_turn_len, max_turn_len]
-        sim = fluid.layers.concat(input=sim_turns, axis=2)
+        if self.use_stack_op:
+            sim = fluid.layers.stack(sim_turns, axis=2)
+        else:
+            for index in xrange(len(sim_turns)):
+                sim_turns[index] = fluid.layers.unsqueeze(
+                    input=sim_turns[index], axes=[2])
+            # sim shape: [batch_size, 2*(stack_num+2), max_turn_num, max_turn_len, max_turn_len]
+            sim = fluid.layers.concat(input=sim_turns, axis=2)

         # for douban
         final_info = layers.cnn_3d(sim, 32, 16)
......
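With `use_stack_op` enabled, the per-tensor `unsqueeze` plus `concat` sequence collapses into a single `stack` op (and likewise along the turn dimension of `sim`); the `else` branch keeps the old path. The two paths produce identical tensors, as this numpy sketch shows (shapes are illustrative only):

```python
# stack(xs, axis=1) is the fused form of "insert a unit axis into each slice,
# then concatenate along that axis" -- the old and new code paths agree.
import numpy as np

xs = [np.random.rand(4, 50, 200) for _ in range(7)]  # 7 feature slices

stacked = np.stack(xs, axis=1)                                   # new path
concatenated = np.concatenate([x[:, None] for x in xs], axis=1)  # old path
assert np.array_equal(stacked, concatenated)
assert stacked.shape == (4, 7, 50, 200)
```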
......@@ -52,7 +52,8 @@ def dot_product_attention(query,
                           d_key,
                           q_mask=None,
                           k_mask=None,
-                          dropout_rate=None):
+                          dropout_rate=None,
+                          mask_cache=None):
    """Dot product layer.

    Args:
......@@ -75,10 +76,17 @@
     logits = logits * (d_key**(-0.5))

     if (q_mask is not None) and (k_mask is not None):
-        q_mask = fluid.layers.unsqueeze(input=q_mask, axes=[-1])
-        k_mask = fluid.layers.unsqueeze(input=k_mask, axes=[-1])
-        mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
-        logits = mask * logits + (1 - mask) * (-2**32 + 1)
+        if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[
+                q_mask.name]:
+            mask, another_mask = mask_cache[q_mask.name][k_mask.name]
+        else:
+            mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
+            another_mask = (1 - mask) * (-2**32 + 1)
+            if mask_cache is not None:
+                mask_cache[q_mask.name] = dict()
+                mask_cache[q_mask.name][k_mask.name] = [mask, another_mask]
+
+        logits = mask * logits + another_mask

     attention = fluid.layers.softmax(logits)
     if dropout_rate:
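The `mask_cache` memoizes the `(mask, another_mask)` pair keyed by the variable names of `q_mask` and `k_mask`, so the mask matmul is emitted once per graph rather than once per attention call. Note that, as committed, the miss branch recreates `mask_cache[q_mask.name]` unconditionally, which would evict an earlier entry for a different `k_mask` under the same `q_mask`. A standalone sketch of the same memoization idea, using a hypothetical helper `get_or_build_masks` (not part of this commit) that avoids the eviction via `setdefault`:

```python
import collections

FakeVar = collections.namedtuple("FakeVar", ["name"])  # stand-in for a fluid Variable

def get_or_build_masks(mask_cache, q_mask, k_mask, build_fn):
    """Return (mask, another_mask), reusing a cached pair when possible."""
    if mask_cache is not None:
        cached = mask_cache.get(q_mask.name, {}).get(k_mask.name)
        if cached is not None:
            return cached
    pair = build_fn(q_mask, k_mask)
    if mask_cache is not None:
        # setdefault keeps earlier k_mask entries under the same q_mask
        mask_cache.setdefault(q_mask.name, {})[k_mask.name] = pair
    return pair

calls = []
def build(q, k):
    calls.append(1)                    # count how often masks are rebuilt
    return ("mask", "additive_mask")   # stand-ins for the two tensors

cache = {}
first = get_or_build_masks(cache, FakeVar("q"), FakeVar("k"), build)
second = get_or_build_masks(cache, FakeVar("q"), FakeVar("k"), build)
assert first == second and len(calls) == 1  # second lookup hits the cache
```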
......@@ -98,12 +106,20 @@ def block(name,
           q_mask=None,
           k_mask=None,
           is_layer_norm=True,
-          dropout_rate=None):
+          dropout_rate=None,
+          mask_cache=None):
    """
    """

-    att_out = dot_product_attention(query, key, value, d_key, q_mask, k_mask,
-                                    dropout_rate)
+    att_out = dot_product_attention(
+        query,
+        key,
+        value,
+        d_key,
+        q_mask,
+        k_mask,
+        dropout_rate,
+        mask_cache=mask_cache)
    y = query + att_out
    if is_layer_norm:
......
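For reference, `another_mask = (1 - mask) * (-2**32 + 1)` is the usual additive softmax mask: padded positions receive a value near -2^32, so their attention weight vanishes after the softmax. A toy numpy illustration:

```python
# Effect of the precomputed additive mask on the softmax: the padded key
# position ends up with (almost) zero attention weight.
import numpy as np

logits = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[1.0, 1.0, 0.0]])        # last key position is padding
another_mask = (1 - mask) * (-2**32 + 1)  # cached alongside `mask`
masked = mask * logits + another_mask

exp = np.exp(masked - masked.max(axis=-1, keepdims=True))  # stable softmax
attention = exp / exp.sum(axis=-1, keepdims=True)
assert attention[0, 2] < 1e-6  # padded key gets (almost) no attention
```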
......@@ -203,17 +203,17 @@ def make_one_batch_input(data_batches, index):
     for i, turn_len in enumerate(every_turn_len_list):
         feed_dict["turn_mask_%d" % i] = np.ones(
-            (batch_size, max_turn_len)).astype("float32")
+            (batch_size, max_turn_len, 1)).astype("float32")
         for row in xrange(batch_size):
-            feed_dict["turn_mask_%d" % i][row, turn_len[row]:] = 0
+            feed_dict["turn_mask_%d" % i][row, turn_len[row]:, 0] = 0

     feed_dict["response"] = response
     feed_dict["response"] = np.expand_dims(feed_dict["response"], axis=-1)

     feed_dict["response_mask"] = np.ones(
-        (batch_size, max_turn_len)).astype("float32")
+        (batch_size, max_turn_len, 1)).astype("float32")
     for row in xrange(batch_size):
-        feed_dict["response_mask"][row, response_len[row]:] = 0
+        feed_dict["response_mask"][row, response_len[row]:, 0] = 0

     feed_dict["label"] = np.array([data_batches["label"][index]]).reshape(
         [-1, 1]).astype("float32")
......
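The feed side mirrors the new mask layout: each mask is materialized as a `[batch_size, max_turn_len, 1]` float array, zeroed past each sequence's true length. A minimal numpy sketch (batch size and lengths are made-up example values):

```python
# Building one feed mask in the new [batch_size, max_turn_len, 1] layout.
import numpy as np

batch_size, max_turn_len = 3, 6
response_len = [4, 6, 2]  # true (unpadded) length of each response

response_mask = np.ones((batch_size, max_turn_len, 1), dtype="float32")
for row in range(batch_size):
    response_mask[row, response_len[row]:, 0] = 0  # zero out padding slots

assert response_mask[0, :, 0].tolist() == [1, 1, 1, 1, 0, 0]
```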