提交 9396c6d9 编写于 作者: Y ying

fix bugs.

上级 3be6c736
......@@ -21,8 +21,6 @@ from ..framework import Variable
from ..param_attr import ParamAttr
from tensor import concat
import pdb
__all__ = [
'fc',
'embedding',
......@@ -1966,7 +1964,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
__check_input(x, y)
helper = LayerHelper('matmul', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='matmul',
inputs={'X': x,
......
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pdb
import layers
__all__ = [
......@@ -163,7 +162,7 @@ def glu(input, dim=-1):
def scaled_dot_product_attention(queries,
keys,
values,
num_heads,
num_heads=1,
dropout_rate=0.):
"""
The dot-product attention.
......@@ -259,9 +258,12 @@ def scaled_dot_product_attention(queries,
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
return layers.reshape(x=layers.reshape(
return layers.reshape(
x=trans_x,
shape=[trans_x.shape[0], trans_x[1], trans_x[2] * trans_x[3]]))
shape=map(int, [
trans_x.shape[0], trans_x.shape[1],
trans_x.shape[2] * trans_x.shape[3]
]))
q = __split_heads(queries, num_heads)
k = __split_heads(keys, num_heads)
......@@ -271,10 +273,11 @@ def scaled_dot_product_attention(queries,
scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
product = layers.matmul(x=k, y=scaled_q, transpose_y=True)
attn_scores = layers.reshape(
weights = layers.reshape(
x=layers.reshape(
x=product, shape=[-1, product.shape[-1]], act="softmax"),
shape=product.shape)
ctx_multiheads = layers.matmul(attn_scores, v)
context = __combine_heads(ctx_multiheads)
return context
if dropout_rate:
weights = layers.dropout(x, dropout_prob=dropout_rate, is_test=False)
ctx_multiheads = layers.matmul(weights, v)
return __combine_heads(ctx_multiheads)
......@@ -17,8 +17,6 @@ import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import numpy as np
import pdb
class TestMultiheadAttention(unittest.TestCase):
def gen_random_input(self):
......@@ -45,7 +43,7 @@ class TestMultiheadAttention(unittest.TestCase):
append_batch_size=False)
keys.stop_gradient = False
contexts, att_scores = fluid.nets.scaled_dot_product_attention(
contexts = fluid.nets.scaled_dot_product_attention(
queries=queries,
keys=keys,
values=keys,
......@@ -84,20 +82,14 @@ class TestMultiheadAttention(unittest.TestCase):
keys.set(self.keys, place)
self.inputs["keys"] = keys
self.inputs["values"] = values
self.inputs["queries"] = queries
def test_multihead_attention(self):
self.gen_random_input()
self.set_program()
pdb.set_trace()
self.run_program()
expect_output = self.l2_normalize(self.data, axis, epsilon)
# check output
self.assertTrue(np.allclose(self.op_output, expect_output, atol=0.001))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册