# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph

_enable_legacy_dygraph()

default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):

    def setUp(self):
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Due to a precision issue on A100, atol is temporarily
        # set to 1e-2; it will be restored once the issue is resolved.
        self.atol = 1e-2
        # Use a tighter tolerance for local development on V100.
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type is np.float16:
            self.atol = 1e-1

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # Use autograd to check gradients in this unittest.
        self.__class__.no_need_check_grad = True
        self.q_proj = Linear(self.embed_dim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.k_proj = Linear(self.kdim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.v_proj = Linear(self.vdim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.out_proj = Linear(self.embed_dim,
                               self.embed_dim,
                               self.weight_attr,
                               bias_attr=self.bias_attr)
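        # Keep the layer norm parameters in fp32 even when inputs are fp16.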
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
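        # Default configuration: fp32 inputs, post layer norm, a float causal
        # attention mask, and no KV cache.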
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = self.query_length, self.query_length

    def generate_input_data(self):
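        # Random query activations plus, optionally, a random KV cache and an
        # attention mask of shape [batch, num_heads, query_len, out_seq_len].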
        self.query = np.random.rand(self.batch_size, self.query_length,
                                    self.embed_dim).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, (
                "cache_kv can only be used in inference")
            self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
                                           self.cache_length,
                                           self.head_dim).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
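            # Causal masking: int64 masks are 0/1 lower-triangular, float
            # masks use 0 for visible positions and -1e9 for masked ones.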
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones((self.batch_size, self.num_heads,
                                      self.query_length, out_seq_len),
                                     dtype=self.attn_mask_type)
            if self.attn_mask_type == np.int64:
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'.")
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        self.dout = np.random.random((self.batch_size, self.query_length,
                                      self.embed_dim)).astype(self.x_type)

    def GetBaselineOut(self):
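        # Reference path: compose the attention block from standard paddle ops
        # (separate Q/K/V projections, softmax, dropout, output projection,
        # residual add and layer norm) to compare against the fused op.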
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(x=q_out,
                               y=k_out,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(softmax_out,
                                    self.dropout_prob,
                                    training=self.training,
                                    mode="upscale_in_train")
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
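        # Fused path: pack the separate Q/K/V weights and biases into single
        # tensors and call the fused_multi_head_attention op.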
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(self.q_proj.weight,
                                         stop_gradient=False)
        k_proj_weight = paddle.to_tensor(self.k_proj.weight,
                                         stop_gradient=False)
        v_proj_weight = paddle.to_tensor(self.v_proj.weight,
                                         stop_gradient=False)
        out_linear_weight = paddle.to_tensor(self.out_proj.weight,
                                             stop_gradient=False)

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(self.q_proj.bias,
                                           stop_gradient=False)
            k_proj_bias = paddle.to_tensor(self.k_proj.bias,
                                           stop_gradient=False)
            v_proj_bias = paddle.to_tensor(self.v_proj.bias,
                                           stop_gradient=False)
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(self.out_proj.bias,
                                               stop_gradient=False)

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight))
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim))
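        # qkv_weight now has the packed layout [3, num_heads, head_dim,
        # embed_dim] consumed by the fused op.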

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
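        # A single fused op covers (pre/post) layer norm, QKV projection,
        # attention, output projection, dropout and the residual add.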
        final_out = incubate_f.fused_multi_head_attention(
            x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm,
            ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
            out_linear_bias, cache_kv, attn_mask, self.dropout_prob,
            self.attn_dropout_prob, ln2_epsilon)

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        return final_out, x.grad

    def test_fused_attention_op(self):
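        # The fused op should match the baseline for both the forward output
        # and the gradient w.r.t. the input query.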
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(final_out_ref,
                                   final_out.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)
        np.testing.assert_allclose(x_grad_ref,
                                   x_grad.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.pre_layer_norm = True
        self.has_attn_mask = False


class TestFusedAttentionOpFp16(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(final_out_ref,
                                   final_out.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)
        np.testing.assert_allclose(x_grad_ref,
                                   x_grad.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):

    def config(self):
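        # Inference-style decode step: a single query token attends over the
        # cached keys/values plus itself.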
        super().config()
        self.has_cache_kv = True
        self.training = False
        self.query_length = 1
        self.key_length, self.value_length = 1, 1

    def test_fused_attention_op(self):
        with paddle.no_grad():
            final_out_ref = self.GetBaselineOut()
            final_out, cache_kv_out = self.GetFusedAttentionOut()
            np.testing.assert_allclose(final_out_ref,
                                       final_out.numpy(),
                                       rtol=self.rtol,
                                       atol=self.atol)


if __name__ == "__main__":
    unittest.main()