Commit 9396c6d9 authored by ying

fix bugs.

Parent 3be6c736
@@ -21,8 +21,6 @@ from ..framework import Variable
 from ..param_attr import ParamAttr
 from tensor import concat

-import pdb
-
 __all__ = [
     'fc',
     'embedding',
@@ -1966,7 +1964,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     __check_input(x, y)

     helper = LayerHelper('matmul', **locals())
-    out = helper.create_tmp_variable(dtype=helper.input_dtype())
+    out = helper.create_tmp_variable(dtype=x.dtype)
     helper.append_op(
         type='matmul',
         inputs={'X': x,
...
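The matmul hunk above only changes which dtype is used when the wrapper creates its temporary output variable. Below is a minimal sketch of calling the wrapper, assuming the fluid layers API of this branch; the variable names and shapes are illustrative and not taken from this commit:

import paddle.v2.fluid as fluid

# x: [batch, M, K], y: [batch, K, N]. After this change the output variable
# is created with x's dtype instead of the result of helper.input_dtype().
x = fluid.layers.data(
    name='x', shape=[2, 4, 8], dtype='float32', append_batch_size=False)
y = fluid.layers.data(
    name='y', shape=[2, 8, 16], dtype='float32', append_batch_size=False)
out = fluid.layers.matmul(x=x, y=y)  # out.dtype now follows x.dtype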
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb
 import layers

 __all__ = [
@@ -163,7 +162,7 @@ def glu(input, dim=-1):
 def scaled_dot_product_attention(queries,
                                  keys,
                                  values,
-                                 num_heads,
+                                 num_heads=1,
                                  dropout_rate=0.):
     """
     The dot-product attention.
@@ -259,9 +258,12 @@ def scaled_dot_product_attention(queries,
             raise ValueError("Input(x) should be a 4-D Tensor.")

         trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
-        return layers.reshape(x=layers.reshape(
-            x=trans_x,
-            shape=[trans_x.shape[0], trans_x[1], trans_x[2] * trans_x[3]]))
+        return layers.reshape(
+            x=trans_x,
+            shape=map(int, [
+                trans_x.shape[0], trans_x.shape[1],
+                trans_x.shape[2] * trans_x.shape[3]
+            ]))

     q = __split_heads(queries, num_heads)
     k = __split_heads(keys, num_heads)
@@ -271,10 +273,11 @@ def scaled_dot_product_attention(queries,
     scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
     product = layers.matmul(x=k, y=scaled_q, transpose_y=True)

-    attn_scores = layers.reshape(
+    weights = layers.reshape(
         x=layers.reshape(
             x=product, shape=[-1, product.shape[-1]], act="softmax"),
         shape=product.shape)
-    ctx_multiheads = layers.matmul(attn_scores, v)
-    context = __combine_heads(ctx_multiheads)
-    return context
+    if dropout_rate:
+        weights = layers.dropout(x, dropout_prob=dropout_rate, is_test=False)
+    ctx_multiheads = layers.matmul(weights, v)
+    return __combine_heads(ctx_multiheads)
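For reference, here is a minimal NumPy sketch of what the attention path is meant to compute in the standard scaled dot-product formulation, where softmax normalizes over the key positions and dropout, when enabled, is applied to the attention weights. The function name, shapes, and dropout handling below are illustrative assumptions, not code from this commit:

import numpy as np

def ref_scaled_dot_product_attention(q, k, v, dropout_rate=0.0, rng=None):
    # q: [batch, len_q, d_k], k: [batch, len_k, d_k], v: [batch, len_k, d_v]
    scores = np.matmul(q, k.transpose(0, 2, 1)) / np.sqrt(k.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over key positions
    if dropout_rate and rng is not None:
        # Inverted dropout applied to the attention weights.
        mask = rng.binomial(1, 1.0 - dropout_rate, size=weights.shape)
        weights = weights * mask / (1.0 - dropout_rate)
    return np.matmul(weights, v)                    # [batch, len_q, d_v]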
@@ -17,8 +17,6 @@ import paddle.v2.fluid as fluid
 import paddle.v2.fluid.core as core
 import numpy as np

-import pdb
-
 class TestMultiheadAttention(unittest.TestCase):
     def gen_random_input(self):
@@ -45,7 +43,7 @@ class TestMultiheadAttention(unittest.TestCase):
             append_batch_size=False)
         keys.stop_gradient = False

-        contexts, att_scores = fluid.nets.scaled_dot_product_attention(
+        contexts = fluid.nets.scaled_dot_product_attention(
             queries=queries,
             keys=keys,
             values=keys,
@@ -84,20 +82,14 @@ class TestMultiheadAttention(unittest.TestCase):
         keys.set(self.keys, place)
         self.inputs["keys"] = keys
-        self.inputs["values"] = values
+        self.inputs["queries"] = queries

     def test_multihead_attention(self):
         self.gen_random_input()
         self.set_program()
-        pdb.set_trace()
         self.run_program()

-        expect_output = self.l2_normalize(self.data, axis, epsilon)
-
-        # check output
-        self.assertTrue(np.allclose(self.op_output, expect_output, atol=0.001))

 if __name__ == '__main__':
     unittest.main()
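A hypothetical end-to-end sketch of driving the updated interface the way this test does, with LoDTensor feeds and a CPU place. The tensor shapes (and the choice of equal query/key lengths) are illustrative assumptions, not values from this commit:

import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core

# Build a program with the updated interface: a single return value and
# num_heads defaulting to 1.
queries = fluid.layers.data(
    name='queries', shape=[3, 5, 9], dtype='float32', append_batch_size=False)
keys = fluid.layers.data(
    name='keys', shape=[3, 5, 9], dtype='float32', append_batch_size=False)
contexts = fluid.nets.scaled_dot_product_attention(
    queries=queries, keys=keys, values=keys)

place = core.CPUPlace()

def to_lodtensor(array):
    # Wrap a numpy array in a LoDTensor, mirroring how the test builds self.inputs.
    tensor = core.LoDTensor()
    tensor.set(array, place)
    return tensor

feed = {
    'queries': to_lodtensor(np.random.random((3, 5, 9)).astype('float32')),
    'keys': to_lodtensor(np.random.random((3, 5, 9)).astype('float32')),
}
exe = fluid.Executor(place)
outputs = exe.run(fluid.default_main_program(),
                  feed=feed,
                  fetch_list=[contexts])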