Commit aa3e072e authored by S sneaxiy

refine dam model

Parent 5d07cee8
@@ -14,8 +14,13 @@ class Net(object):
         self._emb_size = emb_size
         self._stack_num = stack_num
         self.word_emb_name = "shared_word_emb"
+        self.use_stack_op = True
+        self.use_mask_cache = True
+        self.use_sparse_embedding = True
 
     def create_network(self):
+        mask_cache = dict() if self.use_mask_cache else None
+
         turns_data = []
         for i in xrange(self._max_turn_num):
             turn = fluid.layers.data(
@@ -28,19 +33,22 @@ class Net(object):
         for i in xrange(self._max_turn_num):
             turn_mask = fluid.layers.data(
                 name="turn_mask_%d" % i,
-                shape=[self._max_turn_len],
+                shape=[self._max_turn_len, 1],
                 dtype="float32")
             turns_mask.append(turn_mask)
 
         response = fluid.layers.data(
             name="response", shape=[self._max_turn_len, 1], dtype="int32")
         response_mask = fluid.layers.data(
-            name="response_mask", shape=[self._max_turn_len], dtype="float32")
+            name="response_mask",
+            shape=[self._max_turn_len, 1],
+            dtype="float32")
         label = fluid.layers.data(name="label", shape=[1], dtype="float32")
 
         response_emb = fluid.layers.embedding(
             input=response,
             size=[self._vocab_size + 1, self._emb_size],
+            is_sparse=self.use_sparse_embedding,
             param_attr=fluid.ParamAttr(
                 name=self.word_emb_name,
                 initializer=fluid.initializer.Normal(scale=0.1)))
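Note: the mask inputs now carry an explicit trailing dimension of 1 ([max_turn_len, 1] instead of [max_turn_len]), so dot_product_attention further down can form the pairwise attention mask directly with a batched matmul and drop its unsqueeze calls. A minimal NumPy sketch of that product (batch size and lengths are made up for illustration):

import numpy as np

batch_size, max_turn_len = 2, 5
# Per-position masks with an explicit trailing dim, as fed by the reader below.
q_mask = np.ones((batch_size, max_turn_len, 1), dtype="float32")
k_mask = np.ones((batch_size, max_turn_len, 1), dtype="float32")
q_mask[0, 3:, 0] = 0  # pretend the first sample has only 3 valid query tokens
k_mask[0, 4:, 0] = 0  # ... and 4 valid key tokens

# Batched matmul of [B, L, 1] x [B, 1, L] -> pairwise mask [B, L, L],
# mirroring fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True).
mask = np.matmul(q_mask, k_mask.transpose(0, 2, 1))
print(mask.shape)        # (2, 5, 5)
print(mask[0])           # 1 only where both query and key positions are valid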
@@ -57,7 +65,8 @@ class Net(object):
                 value=Hr,
                 d_key=self._emb_size,
                 q_mask=response_mask,
-                k_mask=response_mask)
+                k_mask=response_mask,
+                mask_cache=mask_cache)
             Hr_stack.append(Hr)
 
         # context part
@@ -66,6 +75,7 @@ class Net(object):
             Hu = fluid.layers.embedding(
                 input=turns_data[t],
                 size=[self._vocab_size + 1, self._emb_size],
+                is_sparse=self.use_sparse_embedding,
                 param_attr=fluid.ParamAttr(
                     name=self.word_emb_name,
                     initializer=fluid.initializer.Normal(scale=0.1)))
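Note: with is_sparse set, the shared embedding lookup emits a sparse gradient, which typically means the optimizer only updates the rows of shared_word_emb whose ids actually appear in the batch rather than the full [vocab_size + 1, emb_size] table. A rough NumPy sketch of the idea (illustrative row-wise update, not the fluid implementation):

import numpy as np

vocab_size, emb_size = 10, 4
table = np.zeros((vocab_size, emb_size), dtype="float32")
ids = np.array([2, 7, 2])                     # ids seen in one batch
grad_rows = np.random.rand(len(ids), emb_size).astype("float32")

# Dense view: a full-table gradient, mostly zeros for large vocabularies.
dense_grad = np.zeros_like(table)
np.add.at(dense_grad, ids, grad_rows)
print(np.nonzero(dense_grad.sum(axis=1))[0])  # -> [2 7]

# Sparse-style view: only the looked-up rows are touched during the update.
for row, g in zip(ids, grad_rows):
    table[row] -= 0.1 * g                     # SGD-like row update, lr = 0.1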
@@ -80,7 +90,8 @@ class Net(object):
                     value=Hu,
                     d_key=self._emb_size,
                     q_mask=turns_mask[t],
-                    k_mask=turns_mask[t])
+                    k_mask=turns_mask[t],
+                    mask_cache=mask_cache)
                 Hu_stack.append(Hu)
 
             # cross attention
@@ -94,7 +105,8 @@ class Net(object):
                     value=Hr_stack[index],
                     d_key=self._emb_size,
                     q_mask=turns_mask[t],
-                    k_mask=response_mask)
+                    k_mask=response_mask,
+                    mask_cache=mask_cache)
                 r_a_t = layers.block(
                     name="r_attend_t_" + str(index),
                     query=Hr_stack[index],
@@ -102,7 +114,8 @@ class Net(object):
                     value=Hu_stack[index],
                     d_key=self._emb_size,
                     q_mask=response_mask,
-                    k_mask=turns_mask[t])
+                    k_mask=turns_mask[t],
+                    mask_cache=mask_cache)
                 t_a_r_stack.append(t_a_r)
                 r_a_t_stack.append(r_a_t)
 
@@ -110,25 +123,32 @@ class Net(object):
             t_a_r_stack.extend(Hu_stack)
             r_a_t_stack.extend(Hr_stack)
 
-            for index in xrange(len(t_a_r_stack)):
-                t_a_r_stack[index] = fluid.layers.unsqueeze(
-                    input=t_a_r_stack[index], axes=[1])
-                r_a_t_stack[index] = fluid.layers.unsqueeze(
-                    input=r_a_t_stack[index], axes=[1])
+            if self.use_stack_op:
+                t_a_r = fluid.layers.stack(t_a_r_stack, axis=1)
+                r_a_t = fluid.layers.stack(r_a_t_stack, axis=1)
+            else:
+                for index in xrange(len(t_a_r_stack)):
+                    t_a_r_stack[index] = fluid.layers.unsqueeze(
+                        input=t_a_r_stack[index], axes=[1])
+                    r_a_t_stack[index] = fluid.layers.unsqueeze(
+                        input=r_a_t_stack[index], axes=[1])
 
-            t_a_r = fluid.layers.concat(input=t_a_r_stack, axis=1)
-            r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)
+                t_a_r = fluid.layers.concat(input=t_a_r_stack, axis=1)
+                r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)
 
             # sim shape: [batch_size, 2*(stack_num+2), max_turn_len, max_turn_len]
             sim = fluid.layers.matmul(x=t_a_r, y=r_a_t, transpose_y=True)
             sim = fluid.layers.scale(x=sim, scale=1 / np.sqrt(200.0))
             sim_turns.append(sim)
 
-        for index in xrange(len(sim_turns)):
-            sim_turns[index] = fluid.layers.unsqueeze(
-                input=sim_turns[index], axes=[2])
-        # sim shape: [batch_size, 2*(stack_num+2), max_turn_num, max_turn_len, max_turn_len]
-        sim = fluid.layers.concat(input=sim_turns, axis=2)
+        if self.use_stack_op:
+            sim = fluid.layers.stack(sim_turns, axis=2)
+        else:
+            for index in xrange(len(sim_turns)):
+                sim_turns[index] = fluid.layers.unsqueeze(
+                    input=sim_turns[index], axes=[2])
+            # sim shape: [batch_size, 2*(stack_num+2), max_turn_num, max_turn_len, max_turn_len]
+            sim = fluid.layers.concat(input=sim_turns, axis=2)
 
         # for douban
         final_info = layers.cnn_3d(sim, 32, 16)
......
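Note on the use_stack_op branch above: stack(xs, axis=n) produces the same tensor as unsqueezing every element on axis n and concatenating, so the switch replaces a chain of unsqueeze ops plus a concat with a single op per turn without changing the result. A quick NumPy check of the equivalence on toy shapes:

import numpy as np

xs = [np.random.rand(3, 4, 5) for _ in range(6)]  # e.g. per-layer representations

stacked = np.stack(xs, axis=1)                                      # new path: one op
unsq_cat = np.concatenate([x[:, np.newaxis] for x in xs], axis=1)   # old path

print(stacked.shape)                        # (3, 6, 4, 5)
print(np.array_equal(stacked, unsq_cat))    # True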
@@ -52,7 +52,8 @@ def dot_product_attention(query,
                           d_key,
                           q_mask=None,
                           k_mask=None,
-                          dropout_rate=None):
+                          dropout_rate=None,
+                          mask_cache=None):
     """Dot product layer.
 
     Args:
@@ -75,10 +76,17 @@ def dot_product_attention(query,
     logits = logits * (d_key**(-0.5))
 
     if (q_mask is not None) and (k_mask is not None):
-        q_mask = fluid.layers.unsqueeze(input=q_mask, axes=[-1])
-        k_mask = fluid.layers.unsqueeze(input=k_mask, axes=[-1])
-        mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
-        logits = mask * logits + (1 - mask) * (-2**32 + 1)
+        if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[
+                q_mask.name]:
+            mask, another_mask = mask_cache[q_mask.name][k_mask.name]
+        else:
+            mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
+            another_mask = (1 - mask) * (-2**32 + 1)
+            if mask_cache is not None:
+                mask_cache[q_mask.name] = dict()
+                mask_cache[q_mask.name][k_mask.name] = [mask, another_mask]
+
+        logits = mask * logits + another_mask
 
     attention = fluid.layers.softmax(logits)
     if dropout_rate:
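Note: the cache is a plain dict keyed by the q_mask and k_mask variable names. The first call for a given pair builds both the multiplicative mask and the additive (1 - mask) * (-2**32 + 1) term and stores them; every later block that attends with the same pair of masks reuses those tensors instead of adding new matmul and scale ops to the graph. A stripped-down sketch of the same pattern, with NumPy arrays standing in for fluid variables (the name attribute is faked here for illustration):

import numpy as np

class NamedMask(object):
    """Stand-in for a fluid Variable: an array plus a graph-unique name."""
    def __init__(self, name, data):
        self.name = name
        self.data = data

def get_masks(q_mask, k_mask, mask_cache=None):
    # Reuse the cached masks when this (q_mask, k_mask) pair was seen before.
    if mask_cache is not None and q_mask.name in mask_cache \
            and k_mask.name in mask_cache[q_mask.name]:
        return mask_cache[q_mask.name][k_mask.name]
    mask = np.matmul(q_mask.data, k_mask.data.transpose(0, 2, 1))
    another_mask = (1 - mask) * (-2 ** 32 + 1)
    if mask_cache is not None:
        mask_cache.setdefault(q_mask.name, {})[k_mask.name] = [mask, another_mask]
    return [mask, another_mask]

cache = {}
q = NamedMask("turn_mask_0", np.ones((2, 5, 1), dtype="float32"))
k = NamedMask("response_mask", np.ones((2, 5, 1), dtype="float32"))
m1 = get_masks(q, k, cache)
m2 = get_masks(q, k, cache)
print(m1[0] is m2[0])  # True: the second call hits the cache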
@@ -98,12 +106,20 @@ def block(name,
           q_mask=None,
           k_mask=None,
           is_layer_norm=True,
-          dropout_rate=None):
+          dropout_rate=None,
+          mask_cache=None):
     """
     """
 
-    att_out = dot_product_attention(query, key, value, d_key, q_mask, k_mask,
-                                    dropout_rate)
+    att_out = dot_product_attention(
+        query,
+        key,
+        value,
+        d_key,
+        q_mask,
+        k_mask,
+        dropout_rate,
+        mask_cache=mask_cache)
 
     y = query + att_out
     if is_layer_norm:
......
@@ -203,17 +203,17 @@ def make_one_batch_input(data_batches, index):
     for i, turn_len in enumerate(every_turn_len_list):
         feed_dict["turn_mask_%d" % i] = np.ones(
-            (batch_size, max_turn_len)).astype("float32")
+            (batch_size, max_turn_len, 1)).astype("float32")
         for row in xrange(batch_size):
-            feed_dict["turn_mask_%d" % i][row, turn_len[row]:] = 0
+            feed_dict["turn_mask_%d" % i][row, turn_len[row]:, 0] = 0
 
     feed_dict["response"] = response
     feed_dict["response"] = np.expand_dims(feed_dict["response"], axis=-1)
 
     feed_dict["response_mask"] = np.ones(
-        (batch_size, max_turn_len)).astype("float32")
+        (batch_size, max_turn_len, 1)).astype("float32")
     for row in xrange(batch_size):
-        feed_dict["response_mask"][row, response_len[row]:] = 0
+        feed_dict["response_mask"][row, response_len[row]:, 0] = 0
 
     feed_dict["label"] = np.array([data_batches["label"][index]]).reshape(
         [-1, 1]).astype("float32")
......
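Note: on the feed side, each mask is now built directly with the trailing dimension of 1 that the data layers declare, with positions past the true length zeroed out. A small self-contained check of the construction (toy batch; max_turn_len and the lengths are made up):

import numpy as np

batch_size, max_turn_len = 3, 6
response_len = [6, 4, 2]  # true lengths per sample

response_mask = np.ones((batch_size, max_turn_len, 1)).astype("float32")
for row in range(batch_size):
    response_mask[row, response_len[row]:, 0] = 0

print(response_mask.shape)       # (3, 6, 1)
print(response_mask[1, :, 0])    # [1. 1. 1. 1. 0. 0.]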