未验证 提交 3fb495a7 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #1255 from sneaxiy/dam_develop

Modify dam model
...@@ -137,8 +137,8 @@ class Net(object): ...@@ -137,8 +137,8 @@ class Net(object):
r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1) r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)
# sim shape: [batch_size, 2*(stack_num+2), max_turn_len, max_turn_len] # sim shape: [batch_size, 2*(stack_num+2), max_turn_len, max_turn_len]
sim = fluid.layers.matmul(x=t_a_r, y=r_a_t, transpose_y=True) sim = fluid.layers.matmul(
sim = fluid.layers.scale(x=sim, scale=1 / np.sqrt(200.0)) x=t_a_r, y=r_a_t, transpose_y=True, alpha=1 / np.sqrt(200.0))
sim_turns.append(sim) sim_turns.append(sim)
if self.use_stack_op: if self.use_stack_op:
......
...@@ -72,8 +72,8 @@ def dot_product_attention(query, ...@@ -72,8 +72,8 @@ def dot_product_attention(query,
type is dot. type is dot.
""" """
logits = fluid.layers.matmul(x=query, y=key, transpose_y=True) logits = fluid.layers.matmul(
logits = logits * (d_key**(-0.5)) x=query, y=key, transpose_y=True, alpha=d_key**(-0.5))
if (q_mask is not None) and (k_mask is not None): if (q_mask is not None) and (k_mask is not None):
if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[ if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[
...@@ -81,9 +81,12 @@ def dot_product_attention(query, ...@@ -81,9 +81,12 @@ def dot_product_attention(query,
mask, another_mask = mask_cache[q_mask.name][k_mask.name] mask, another_mask = mask_cache[q_mask.name][k_mask.name]
else: else:
mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True) mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
another_mask = (1 - mask) * (-2**32 + 1) another_mask = fluid.layers.scale(
mask, scale=2**32 - 1, bias=-1, bias_after_scale=False)
if mask_cache is not None: if mask_cache is not None:
mask_cache[q_mask.name] = dict() if q_mask.name not in mask_cache:
mask_cache[q_mask.name] = dict()
mask_cache[q_mask.name][k_mask.name] = [mask, another_mask] mask_cache[q_mask.name][k_mask.name] = [mask, another_mask]
logits = mask * logits + another_mask logits = mask * logits + another_mask
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册