You need to sign in or sign up before continuing.
提交 dace68ac 编写于 作者: Y ying

Merge branch 'develop' into multihead_attention

...@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False): ...@@ -305,9 +305,9 @@ def get_dict(lang, dict_size, reverse=False):
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
assert (os.path.exists(dict_path), "Word dictionary does not exist. " assert os.path.exists(dict_path), "Word dictionary does not exist. "
"Please invoke paddle.dataset.wmt16.train/test/validation " "Please invoke paddle.dataset.wmt16.train/test/validation first "
"first to build the dictionary.") "to build the dictionary."
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz") tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
return __load_dict(tar_file, dict_size, lang, reverse) return __load_dict(tar_file, dict_size, lang, reverse)
......
...@@ -248,7 +248,8 @@ def scaled_dot_product_attention(queries, ...@@ -248,7 +248,8 @@ def scaled_dot_product_attention(queries,
reshaped = layers.reshape( reshaped = layers.reshape(
x=x, x=x,
shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads]) shape=list(x.shape[:-1]) + [num_heads, hidden_size // num_heads])
# permuate the original dimensions into:
# permuate the dimensions into:
# [batch_size, num_heads, max_sequence_len, hidden_size_per_head] # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3]) return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册