Unverified commit 59813de9, authored by zhouweiwei2014, committed by GitHub

[Sparse] add SparseCsrTensor fused_attention kernel and API (#43966)

* [Sparse] add SparseCsrTensor fused_attention kernel and API

* fix comment
Parent 7d3b08d9
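For reference, a minimal usage sketch of the new Python API (shapes follow the docstring added in this commit; the mask values here are illustrative, and CUDA >= 11.7 is required):

import paddle

batch_size, num_heads, seq_len, head_dim = 2, 2, 128, 16
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
# CSR sparse mask with dense shape [batch_size*num_heads, seq_len, seq_len]
sp_mask = paddle.ones([batch_size * num_heads, seq_len, seq_len]).to_sparse_csr()
kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float32')  # [batch_size, seq_len]
attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float32')   # [seq_len, seq_len]
out = paddle.incubate.sparse.nn.functional.attention(
    query, key, value, sp_mask, kp_mask, attn_mask)  # [batch_size, num_heads, seq_len, head_dim]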
......@@ -141,6 +141,15 @@
layout : x
data_type : dtype
- api: fused_attention
args : (Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask)
output : Tensor(out), Tensor(softmax)
kernel :
func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
layout : sparse_mask
intermediate : softmax
backward: fused_attention_grad
- api: masked_matmul
args : (Tensor x, Tensor y, Tensor mask)
output : Tensor(out)
......
......@@ -127,3 +127,10 @@
output : Tensor(x_grad)
kernel :
func : coo_values_grad{sparse_coo, dense-> sparse_coo}
- backward_api: fused_attention_grad
forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
kernel :
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue) {
PD_THROW(
"The CPU kernel of 'sparse.nn.functional.fused_attention' is not supported yet");
}
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
PD_THROW(
"The CPU kernel of 'sparse.nn.functional.fused_attention' is not supported yet");
}
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax);
} // namespace sparse
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h"
namespace phi {
namespace sparse {
template <typename T>
__global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
const T* out_values,
const T* dout_values,
T* dx_values,
int M,
int total_row_num,
float scale,
int batch_nnz) {
// dx = (dout - sum(dout * out)) * out / scale
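// Thread mapping (matching the launch config in FusedAttentionCsrGradKernel:
// grid = ceil(total_row_num / 4), block = (WARP_SIZE, 4)): each CSR row is
// handled by one warp of blockDim.x lanes, with 4 rows packed per block.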
int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= total_row_num) return;
int cur_batch = row / M;
int crow_idx = cur_batch * (M + 1) + (row % M);
int row_first = cur_batch * batch_nnz + static_cast<int>(out_crows[crow_idx]);
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;
int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
out_values[row_first + idx] / scale;
}
}
template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue) {
#if CUDA_VERSION >= 11070
/* Step1: Forward: softmax{CSR} * value{Dense} -> out{Dense}, reuse */
SparseCsrTensor dsoftmax;
CsrDenseMatmulGradKernel<T, Context>(
dev_ctx, softmax, value, dout, &dsoftmax, dvalue);
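// For out = softmax{CSR} * value{Dense}: dvalue = softmax^T * dout, and dsoftmax =
// dout * value^T restricted to the sparsity pattern of softmax, both produced by the
// reused matmul grad kernel above.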
/* Step2: Calculate grad of sdd_result manually (no kernel to reuse) */
SparseCsrTensor d_sdd_result;
EmptyLikeCsrKernel<T, Context>(dev_ctx, dsoftmax, &d_sdd_result);
auto q_dim = query.dims();
auto q_rank = q_dim.size();
int total_row_num = 1;
int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
if (i < q_rank - 2) {
batch_num *= q_dim[i];
}
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
int batch_nnz = softmax.nnz() / batch_num;
dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
softmax.non_zero_crows().data<int64_t>(),
softmax.non_zero_elements().data<T>(),
dsoftmax.mutable_non_zero_elements()->data<T>(),
d_sdd_result.mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
std::sqrt(N),
batch_nnz);
/* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
// dquery{Dense} = d_sdd_result{SparseCsr} * key{Dense} //
dquery->Resize(query.dims());
dev_ctx.template Alloc<T>(dquery);
sparse_blas.SPMM(false,
false,
static_cast<T>(1.f),
d_sdd_result,
key,
static_cast<T>(0.f),
dquery);
// dkey{Dense} = d_sdd_result'{SparseCsr} * query{Dense} //
dkey->Resize(key.dims());
dev_ctx.template Alloc<T>(dkey);
sparse_blas.SPMM(true,
false,
static_cast<T>(1.f),
d_sdd_result,
query,
static_cast<T>(0.f),
dkey);
#else
PADDLE_THROW(
phi::errors::Unimplemented("backward of 'sparse.nn.functional.attention' "
"use 'cusparseCsrSetStridedBatch', which is "
"completed supported from CUDA 11.7"));
#endif
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(fused_attention_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::FusedAttentionCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/matmul_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace phi {
namespace sparse {
#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
case size: { \
constexpr int HINT = size; \
__VA_ARGS__(); \
break; \
}
#define VISIT_ATTN_SOFTMAX(SIZE, NAME, ...)                                  \
[&] { \
const auto& __size__ = SIZE; \
switch (__size__) { \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for columns>512 "); \
} \
}()
template <typename T, int BufferSize>
__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
const int64_t* x_cols,
const T* x_values,
const T* kp_mask,
const T* attn_mask,
T* out_values,
int M,
int total_row_num,
float scale,
int num_heads,
int batch_nnz) {
// out = exp(x-x_max) / sum(exp(x-x_max))
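// Masked entries (key_padding_mask or attn_mask equal to 0 at this column) keep
// buffer[i] == 0, are skipped in the exp/sum passes below, and are written out as 0.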
int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= total_row_num) return;
int cur_batch = row / M;
int cur_row = row % M;
int crow_idx = cur_batch * (M + 1) + cur_row;
int row_first = cur_batch * batch_nnz + static_cast<int>(x_crows[crow_idx]);
int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
if (row_nnz == 0) return;
T buffer[BufferSize] = {0};
int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T max_val = -std::numeric_limits<T>::infinity();
for (int i = 0; i < kIteration; ++i) {
bool mask = false;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
int col_idx = static_cast<int>(x_cols[row_first + idx]);
if (kp_mask != nullptr &&
kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
mask = true;
}
if (attn_mask != nullptr && attn_mask[cur_row * M + col_idx] == 0) {
mask = true;
}
if (!mask) {
buffer[i] = x_values[row_first + idx] / scale;
if (buffer[i] > max_val) {
max_val = buffer[i];
}
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
if (buffer[i]) {
T exp = functor(buffer[i] - row_max_val);
exp_sum += exp;
buffer[i] = exp;
}
}
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
if (buffer[i]) {
out_values[row_first + idx] = buffer[i] / row_exp_sum;
} else {
out_values[row_first + idx] = static_cast<T>(0);
}
}
}
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
#if CUDA_VERSION >= 11070
/* Check Shape */
auto q_dim = query.dims();
auto q_rank = q_dim.size();
int total_row_num = 1;
int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
if (i < q_rank - 2) {
batch_num *= q_dim[i];
}
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
PADDLE_ENFORCE_EQ(query.dims().size(),
4,
phi::errors::InvalidArgument(" 'query' must be 4D Tensor"));
PADDLE_ENFORCE_EQ(key.dims().size(),
4,
phi::errors::InvalidArgument(" 'key' must be 4D Tensor"));
PADDLE_ENFORCE_EQ(value.dims().size(),
4,
phi::errors::InvalidArgument(" 'value' must be 4D Tensor"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims().size(),
3,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[0],
q_dim[0] * q_dim[1],
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[1],
M,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[2],
M,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
key_padding_mask.dims().size(),
2,
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
PADDLE_ENFORCE_EQ(
key_padding_mask.dims()[0],
q_dim[0],
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
PADDLE_ENFORCE_EQ(
key_padding_mask.dims()[1],
M,
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
2,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
M,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
M,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
/* Step1: SDD Matmul, reuse */
SparseCsrTensor sdd_result;
EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SDDMM(false,
true,
static_cast<T>(1),
query,
key,
static_cast<T>(0),
&sdd_result);
/* Step2: Softmax with kp_mask/attn_mask, implemented manually (no kernel to reuse) */
EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);
int buffer_size;
if (M < 128) {
buffer_size = (M + 32 - 1) / 32;
} else {
buffer_size = ((M + 128 - 1) / 128) * 4;
}
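// buffer_size is an upper bound on ceil(M / WARP_SIZE), i.e. how many row entries
// each lane keeps in registers; the dispatch below covers {1, 2, 3, 4, 8, 12, 16},
// so rows with more than 512 non-zeros are rejected by the visitor macro.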
dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
int batch_nnz = sdd_result.nnz() / batch_num;
VISIT_ATTN_SOFTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
sdd_result.non_zero_crows().data<int64_t>(),
sdd_result.non_zero_cols().data<int64_t>(),
sdd_result.non_zero_elements().data<T>(),
key_padding_mask.data<T>(),
attn_mask.data<T>(),
softmax->mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
std::sqrt(N),
q_dim[1],
batch_nnz);
});
/* Step3: DSD Matmul, reuse */
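// softmax inherited the 3-D dims of sparse_mask ([batch_size*num_heads, seq_len, seq_len]);
// reset them to 4-D [batch_size, num_heads, seq_len, seq_len] so the reused CSR x Dense
// matmul lines up with the 4-D value tensor.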
softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
CsrDenseMatmulKernel<T, Context>(dev_ctx, *softmax, value, out);
#else
PADDLE_THROW(
phi::errors::Unimplemented("forward of 'sparse.nn.functional.attention' "
"use 'cusparseCsrSetStridedBatch', which is "
"completed supported from CUDA 11.7"));
#endif
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(fused_attention_csr,
GPU,
ALL_LAYOUT,
phi::sparse::FusedAttentionCsrKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
......@@ -76,7 +76,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
out_dim_vec[y_ndims - 1] = ydim_vec[y_ndims - 1];
MetaTensor meta_out(out);
meta_out.set_dims(phi::make_ddim(out_dim_vec));
-  meta_out.set_dtype(x.non_zero_elements().dtype());
+  meta_out.set_dtype(y.dtype());
dev_ctx.template Alloc<T>(out);
......
......@@ -52,8 +52,9 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;
-  if (max_val < x_values[row_first + idx]) {
-    max_val = x_values[row_first + idx];
+  T val = x_values[row_first + idx];
+  if (val > max_val) {
+    max_val = val;
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import re
import copy
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
)
class TestSparseAttentionAPI1(unittest.TestCase):
def setUp(self):
self.batch_size = 16
self.num_heads = 16
self.seq_len = 128
self.head_dim = 16
self.dtype = 'float64'
def test_dygraph(self):
with _test_eager_guard():
self.shape = [
self.batch_size, self.num_heads, self.seq_len, self.head_dim
]
query = paddle.rand(self.shape, self.dtype)
key = paddle.rand(self.shape, self.dtype)
value = paddle.rand(self.shape, self.dtype)
query.stop_gradient = False
key.stop_gradient = False
value.stop_gradient = False
mask = paddle.nn.functional.dropout(paddle.ones(
[self.seq_len, self.seq_len]),
mode='downscale_in_infer')
mask = mask.expand(
[self.batch_size, self.num_heads, self.seq_len, self.seq_len])
sp_mask = mask.reshape([-1, self.seq_len,
self.seq_len]).to_sparse_csr()
kp_mask = paddle.randint(
0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
attn_mask = paddle.randint(
0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
sdd = paddle.matmul(query, key, False, True) / math.sqrt(
float(self.head_dim))
sdd = sdd + (
(mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
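# Masked positions get a -1e9 bias so the dense softmax reference assigns them
# (almost) zero probability, matching the sparse kernel, which excludes them from
# its softmax.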
softmax = paddle.nn.functional.softmax(sdd)
output = paddle.matmul(softmax, value)
output.backward()
query_cp = copy.deepcopy(query)
key_cp = copy.deepcopy(key)
value_cp = copy.deepcopy(value)
query_cp.stop_gradient = False
key_cp.stop_gradient = False
value_cp.stop_gradient = False
output_cp = paddle.incubate.sparse.nn.functional.attention(
query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
output_cp.backward()
self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
self.assertTrue(
np.allclose(query_cp.grad.numpy(), query.grad.numpy()))
self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
self.assertTrue(
np.allclose(value_cp.grad.numpy(), value.grad.numpy()))
class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
def setUp(self):
self.batch_size = 16
self.num_heads = 16
self.seq_len = 128
self.head_dim = 32
self.dtype = 'float64'
class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
def setUp(self):
self.batch_size = 16
self.num_heads = 16
self.seq_len = 512
self.head_dim = 16
self.dtype = 'float64'
class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
def setUp(self):
self.batch_size = 16
self.num_heads = 16
self.seq_len = 512
self.head_dim = 32
self.dtype = 'float64'
class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
def setUp(self):
self.batch_size = 16
self.num_heads = 16
self.seq_len = 512
self.head_dim = 64
self.dtype = 'float64'
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@
from .conv import conv3d # noqa: F401
from .conv import subm_conv3d # noqa: F401
from .transformer import attention # noqa: F401
from .pooling import max_pool3d # noqa: F401
from .activation import relu # noqa: F401
from .activation import softmax # noqa: F401
......@@ -24,4 +25,5 @@ __all__ = [
'max_pool3d',
'relu',
'softmax',
'attention',
]
......@@ -14,7 +14,7 @@
__all__ = []
-from paddle import _C_ops, in_dynamic_mode
+from paddle import _C_ops
from paddle.fluid.framework import dygraph_only
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = []
from paddle import _C_ops
from paddle.fluid.framework import dygraph_only
@dygraph_only
def attention(query,
key,
value,
sparse_mask,
key_padding_mask,
attn_mask,
name=None):
"""
Note:
This API is only supported from ``CUDA 11.7``.
A SparseCsrTensor is used to store the intermediate result of the attention matrix
in the Transformer module, which reduces memory usage and improves performance.
``sparse_mask`` expresses the sparse layout in CSR format.
The calculation equation is:
.. math::
result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where ``Q``, ``K``, and ``V`` are the three input parameters of the attention module,
each with shape `[batch_size, num_heads, seq_len, head_dim]`, and
``d`` represents ``head_dim``.
Args:
query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
is `[batch_size*num_heads, seq_len, seq_len]`. The `nnz` of each batch must be the same.
The dtype of `crows` and `cols` must be int64; the dtype of `values` can be float32 or float64.
key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
attn_mask(DenseTensor): The attention mask tensor in the Attention module.
2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
name(str, optional): The default value is None. Normally there is no need for the user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
4D tensor with shape [batch_size, num_heads, seq_len, head_dim], with the same dtype as the inputs.
Examples:
.. code-block:: python
import paddle
batch_size = 16
num_heads = 16
seq_len = 512
head_dim = 32
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
query.stop_gradient = False
key.stop_gradient = False
value.stop_gradient = False
mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
kp_mask = paddle.randint(0, 2, [batch_size, seq_len])
attn_mask = paddle.randint(0, 2, [seq_len, seq_len])
output = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
output.backward()
"""
return _C_ops.final_state_sparse_fused_attention(query, key, value,
sparse_mask,
key_padding_mask,
attn_mask)