提交 7210d0fa 编写于 作者: Y ying

follow comments.

上级 d00eb53a
...@@ -28,6 +28,6 @@ glu ...@@ -28,6 +28,6 @@ glu
scaled_dot_product_attention scaled_dot_product_attention
---------------------------- ----------------------------
.. autofunction:: paddle.v2.fluid.nets.dot_product_attention .. autofunction:: paddle.v2.fluid.nets.scaled_dot_product_attention
:noindex: :noindex:
...@@ -2097,7 +2097,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): ...@@ -2097,7 +2097,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
if len(x_shape) == 1: if len(x_shape) == 1:
x_shape = [1] + x_shape x_shape = [1] + x_shape
if len(y_shape) == 1: if len(y_shape) == 1:
y_shape = [1] + y_shape y_shape = y_shape + [1]
# check the inner 2 dimensions # check the inner 2 dimensions
if transpose_x: if transpose_x:
......
...@@ -306,7 +306,7 @@ def scaled_dot_product_attention(queries, ...@@ -306,7 +306,7 @@ def scaled_dot_product_attention(queries,
[bs, max_sequence_length, num_heads * hidden_dim]. [bs, max_sequence_length, num_heads * hidden_dim].
""" """
if len(x.shape) == 3: return if len(x.shape) == 3: return x
if len(x.shape) != 4: if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.") raise ValueError("Input(x) should be a 4-D Tensor.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册