未验证 提交 6b587e93 编写于 作者: L Liu-xiandong 提交者: GitHub

Add sparse_attention api, test=develop (#35676)

Add sparse_attention OPs, python api will be added in next pr
上级 bc7e2b92
......@@ -214,7 +214,7 @@ function(op_library TARGET)
foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
......
......@@ -78,7 +78,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op sparse_attention_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
......@@ -94,6 +94,10 @@ if (WITH_GPU OR WITH_ROCM)
endif()
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) )
op_library(sparse_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n")
endif()
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Q",
"(Tensor), The input tensor of query in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
AddInput(
"K",
"(Tensor), The input tensor of key in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
AddInput(
"V",
"(Tensor), The input tensor of value in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
AddInput("Offset",
"(Tensor, default: Tensor<int32>), The input tensor of offset in "
"CSR sparse format, "
"whose dimension : `[batch_size, num_heads, target_len + 1]`.");
AddInput("Columns",
"(Tensor, default: Tensor<int32>), The input tensor of columns in "
"CSR sparse format, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_num]`.");
AddOutput(
"Out",
"(Tensor), The output tensor of result in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
AddOutput("SparseDotSdd",
"(Tensor), The output tensor of result in SparseDotSdd step, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.")
.AsIntermediate();
AddOutput("Softmax",
"(Tensor), The output tensor of result in Softmax step, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.")
.AsIntermediate();
AddComment(R"DOC(
Compute the value of the sparse attention module. Its input value includes five tensors.
Q, K, and V represent query, key, and value in the Attention module, respectively.
The CSR format is used to represent the sparsity feature in the Attention module.
The CSR format contains two tensors, offset and columns.
)DOC");
}
};
class SparseAttentionOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention");
OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention");
OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
"sparse_attention");
OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns",
"sparse_attention");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sparse_attention");
OP_INOUT_CHECK(ctx->HasOutput("SparseDotSdd"), "Output", "SparseDotSdd",
"sparse_attention");
OP_INOUT_CHECK(ctx->HasOutput("Softmax"), "Output", "Softmax",
"sparse_attention");
auto dims_q = ctx->GetInputDim("Q");
auto dims_k = ctx->GetInputDim("K");
auto dims_v = ctx->GetInputDim("V");
auto dims_columns = ctx->GetInputDim("Columns");
PADDLE_ENFORCE_EQ(dims_q.size(), static_cast<size_t>(4),
platform::errors::InvalidArgument(
"Dimension in query' shapes should be 4."));
PADDLE_ENFORCE_EQ(dims_k.size(), static_cast<size_t>(4),
platform::errors::InvalidArgument(
"Dimension in key' shapes should be 4."));
PADDLE_ENFORCE_EQ(dims_v.size(), static_cast<size_t>(4),
platform::errors::InvalidArgument(
"Dimension in value' shapes should be 4."));
auto batch_size = dims_q[0];
auto num_heads = dims_q[1];
auto M = dims_q[2];
auto N = dims_q[3];
auto sparse_nnz = dims_columns[2];
ctx->SetOutputDim("Out", {batch_size, num_heads, M, N});
ctx->SetOutputDim("SparseDotSdd", {batch_size, num_heads, sparse_nnz});
ctx->SetOutputDim("Softmax", {batch_size, num_heads, sparse_nnz});
ctx->ShareLoD("Q", "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "Q", "K");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class SparseAttentionOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
"sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns",
"sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("SparseDotSdd"), "Input", "SparseDotSdd",
"sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput("Softmax"), "Input", "Softmax",
"sparse_attention_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "sparse_attention_grad");
auto x_grad_name = framework::GradVarName("Q");
auto y_grad_name = framework::GradVarName("K");
auto z_grad_name = framework::GradVarName("V");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Q"));
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("K"));
}
if (ctx->HasOutput(z_grad_name)) {
ctx->SetOutputDim(z_grad_name, ctx->GetInputDim("V"));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class SparseAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("sparse_attention_grad");
op->SetInput("Q", this->Input("Q"));
op->SetInput("K", this->Input("K"));
op->SetInput("V", this->Input("V"));
op->SetInput("Offset", this->Input("Offset"));
op->SetInput("Columns", this->Input("Columns"));
op->SetInput("SparseDotSdd", this->Output("SparseDotSdd"));
op->SetInput("Softmax", this->Output("Softmax"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Q"), this->InputGrad("Q"));
op->SetOutput(framework::GradVarName("K"), this->InputGrad("K"));
op->SetOutput(framework::GradVarName("V"), this->InputGrad("V"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_attention, ops::SparseAttentionOp,
ops::SparseAttentionOpMaker,
ops::SparseAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::SparseAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sparse_attention_grad, ops::SparseAttentionOpGrad);
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle
import os
import re
import platform
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def get_linux_platform():
if platform.system().lower() == 'windows':
return 0
elif platform.system().lower() == 'linux':
return 1
else:
return -1
def get_suitable_env():
if get_cuda_version() >= 11020 and get_linux_platform() == 1:
return True
else:
return False
def softmax(x):
max = np.max(x, axis=1, keepdims=True)
e_x = np.exp(x - max)
sum = np.sum(e_x, axis=1, keepdims=True)
f_x = e_x / sum
return f_x
def get_csr_value(mat, layout, nnz):
row, col = mat.shape[0], mat.shape[1]
value = np.zeros(nnz)
ptr = 0
for i in range(row):
for j in range(col):
if layout[i][j] == 1:
value[ptr] = mat[i][j]
ptr += 1
return value
def ref_sparse_attention(q, k, v, offset, columns):
row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
mat = np.zeros((row, row))
for cur_row in range(row):
start_ptr = int(offset[cur_row])
end_ptr = int(offset[cur_row + 1])
for ptr in range(start_ptr, end_ptr):
cur_col = int(columns[ptr])
mat[cur_row][cur_col] = 1
a = np.dot(q, k.T) * mat
a_value = get_csr_value(a, mat, nnz)
scaling = float(col)**-0.5
a = scaling * a
for i in range(row):
for j in range(row):
if mat[i][j] == 0:
a[i][j] = float('-inf')
b = softmax(a)
b_value = get_csr_value(b, mat, nnz)
result = np.dot(b, v)
return result, a_value, b_value
def ref_batch_sparse_attention(q, k, v, offset, columns):
batch_size, num_heads, row, col = q.shape
nnz = columns.shape[2]
result = np.zeros((batch_size, num_heads, row, col))
result_sdd = np.zeros((batch_size, num_heads, nnz))
result_softmax = np.zeros((batch_size, num_heads, nnz))
for i in range(batch_size):
for j in range(num_heads):
cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
cur_offset, cur_columns = offset[i][j], columns[i][j]
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q, cur_k, cur_v, cur_offset, cur_columns)
result[i][j] = cur_result
result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
return result, result_sdd, result_softmax
def init_csr_format(batch_size, num_heads, rows, blocksize):
block_num, block_last = rows / blocksize, rows % blocksize
nnz_num = block_num * blocksize * blocksize + block_last * block_last
offset = np.zeros(rows + 1)
columns = np.zeros(int(nnz_num))
mat = np.zeros((rows, rows))
for i in range(0, rows, blocksize):
for x in range(blocksize):
for y in range(blocksize):
p_x, p_y = i + x, i + y
if (p_x < rows) and (p_y < rows):
mat[p_x][p_y] = 1
p_offset, p_column, count = 0, 0, 0
for i in range(rows):
for j in range(rows):
if mat[i][j] != 0:
count += 1
columns[p_column] = j
p_column += 1
p_offset += 1
offset[p_offset] = count
offset = np.expand_dims(np.expand_dims(offset, 0), 0)
offset = offset.repeat(num_heads, axis=1)
offset = offset.repeat(batch_size, axis=0)
columns = np.expand_dims(np.expand_dims(columns, 0), 0)
columns = columns.repeat(num_heads, axis=1)
columns = columns.repeat(batch_size, axis=0)
return offset, columns
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_suitable_env() == False,
"core is not compiled with CUDA and cuda version need >= 11.2 in windows")
class TestSparseAttentionOp(OpTest):
def config(self):
self.shape = (1, 1, 16, 8)
self.blocksize = 2
self.dtype = "float64"
def setUp(self):
paddle.enable_static()
self.config()
self.op_type = "sparse_attention"
self.place = paddle.CUDAPlace(0)
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape).astype(self.dtype)
self.v = np.random.random(self.shape).astype(self.dtype)
offset, columns = init_csr_format(self.shape[0], self.shape[1],
self.shape[2], self.blocksize)
self.offset = offset.astype('int32')
self.columns = columns.astype('int32')
result, result_sdd, result_softmax = ref_batch_sparse_attention(
self.q, self.k, self.v, self.offset, self.columns)
self.inputs = {
'Q': self.q,
'K': self.k,
'V': self.v,
'offset': self.offset,
'columns': self.columns
}
self.outputs = {
'Out': result.astype(self.dtype),
'ResultSdd': result_sdd.astype(self.dtype),
'ResultSoftmax': result_softmax.astype(self.dtype)
}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['Q'], 'Out')
self.check_grad_with_place(self.place, ['K'], 'Out')
self.check_grad_with_place(self.place, ['V'], 'Out')
class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
def config(self):
self.shape = (1, 1, 8, 16)
self.blocksize = 2
self.dtype = "float32"
class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
def config(self):
self.shape = (2, 2, 32, 8)
self.blocksize = 8
self.dtype = "float64"
if __name__ == '__main__':
unittest.main()
......@@ -46,6 +46,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'cudnn_lstm', \
'rnn', \
'lgamma', \
'sparse_attention', \
'svd', \
'matrix_power', \
'solve', \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册