# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program

default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):
    def setUp(self):
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because there is a problem with the test precision
        #  on A100, atol is temporarily set to 1e-2, and it will be
        #  changed back after the precision problem is solved.
        self.atol = 1e-2
        # Keep a tighter tolerance for local development on V100.
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type is np.float16:
            self.atol = 1e-1

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
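
        # The projections, norms and dropout below form an unfused reference
        # implementation; GetBaselineOut() runs them step by step so the
        # result can be compared against the fused_attention op.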
        self.q_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr)
        self.k_proj = Linear(
            self.kdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr)
        self.v_proj = Linear(
            self.vdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr)
        self.out_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr)
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = self.query_length, self.query_length

    def generate_input_data(self):
        self.query = np.random.rand(self.batch_size, self.query_length,
                                    self.embed_dim).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, (
                'cache_kv can only be used in inference')
            self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
                                           self.cache_length,
                                           self.head_dim).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones(
                (self.batch_size, self.num_heads, self.query_length,
                 out_seq_len),
                dtype=self.attn_mask_type)
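            # Causal (lower-triangular) mask: int64 masks keep 0/1 entries,
            # float64 masks are additive (0 for visible positions, -1e9 for
            # masked ones).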
            if self.attn_mask_type == np.int64:
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'.")
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        self.dout = np.random.random((self.batch_size, self.query_length,
                                      self.embed_dim)).astype(self.x_type)

    def GetBaselineOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

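        # Project Q/K/V separately, reshape each to
        # [B, seq_len, n_head, head_dim] and transpose to
        # [B, n_head, seq_len, head_dim].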
        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
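        # The attention logits are scaled by 1/sqrt(head_dim) via matmul's
        # alpha argument.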
        qk_out = layers.matmul(
            x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=self.training,
                mode="upscale_in_train")
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(
            self.q_proj.weight, stop_gradient=False)
        k_proj_weight = paddle.to_tensor(
            self.k_proj.weight, stop_gradient=False)
        v_proj_weight = paddle.to_tensor(
            self.v_proj.weight, stop_gradient=False)
        out_linear_weight = paddle.to_tensor(
            self.out_proj.weight, stop_gradient=False)

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(
                self.q_proj.bias, stop_gradient=False)
            k_proj_bias = paddle.to_tensor(
                self.k_proj.bias, stop_gradient=False)
            v_proj_bias = paddle.to_tensor(
                self.v_proj.bias, stop_gradient=False)
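            # Pack the three projection biases into the fused layout
            # [3, num_heads, head_dim].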
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(
                self.out_proj.bias, stop_gradient=False)

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

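        # The fused op expects a packed QKV weight of shape
        # [3, num_heads, head_dim, embed_dim], built from the transposed
        # per-projection weights.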
        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight))
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim))

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
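
        # A single fused call that should match the step-by-step computation
        # in GetBaselineOut().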
        final_out = incubate_f.fused_multi_head_attention(
            x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm,
            ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
            out_linear_bias, cache_kv, attn_mask, self.dropout_prob,
            self.attn_dropout_prob, ln2_epsilon)

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
        return final_out, x.grad

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol)
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol)


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True
        self.has_attn_mask = False


class TestFusedAttentionOpFp16(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol)
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol)


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):
    def config(self):
        super().config()
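        # Single-token decoding: query_length = 1 against cached keys/values
        # of length self.cache_length.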
        self.has_cache_kv = True
        self.training = False
        self.query_length = 1
        self.key_length, self.value_length = 1, 1

    def test_fused_attention_op(self):
        with paddle.no_grad():
            final_out_ref = self.GetBaselineOut()
            final_out, cache_kv_out = self.GetFusedAttentionOut()
            np.testing.assert_allclose(
                final_out_ref,
                final_out.numpy(),
                rtol=self.rtol,
                atol=self.atol)


if __name__ == "__main__":
    unittest.main()