未验证 提交 88ea8e6f 编写于 作者: L Li Min 提交者: GitHub

Add fused_attention_op: add impl wrappers. (#35903)

上级 7bf84e2d
......@@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func, InT **args, OutT *result) {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
......
......@@ -50,7 +50,7 @@ class AttnLayerNorm {
}
}
void ComputeBackward(const T* x_data const T* y_data,
void ComputeBackward(const T* x_data, const T* d_y_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* mean_data,
const LayerNormParamType<T>* var_data, T* d_x_data,
......
......@@ -34,6 +34,7 @@ namespace cub = hipcub;
#define LAUNCH_BOUNDS(BlockDim)
#endif
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
......@@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename InT, typename OutT, int ShapeSize, int VecSize,
int DATA_PER_THREAD, typename Functor>
__global__ void BroadcastKernelBinary(
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template <typename T>
class AttnMatMul {
public:
// (m, n, k) = bsz_seq, output_size, input_size
AttnMatMul(const platform::CUDADeviceContext& dev_ctx, bool transA,
bool transB, int bsz_seq, int output_size, int input_size,
bool compute_bias)
: dev_ctx_(dev_ctx),
transA_(transA),
transB_(transB),
bsz_seq_(bsz_seq),
output_size_(output_size),
input_size_(input_size),
compute_bias_(compute_bias) {}
~AttnMatMul() {}
void ComputeForward(const T* weight_data, const T* input_data,
const T* bias_data, T* output_data, T* bias_out_data) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
if (transA_) {
transA = CblasTrans;
}
if (transB_) {
transB = CblasTrans;
}
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
input_data, weight_data, beta, output_data);
if (compute_bias_) {
// compute output + bias
LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
bias_data, bias_out_data);
}
}
void ComputeBackward(const T* input, const T* weight, const T* d_output,
T* d_input, T* d_weight, T* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
CBLAS_TRANSPOSE dB_transA = CblasNoTrans;
CBLAS_TRANSPOSE dB_transB = CblasNoTrans;
CBLAS_TRANSPOSE dA_transA = CblasNoTrans;
CBLAS_TRANSPOSE dA_transB = CblasNoTrans;
int dB_m = 1;
int dB_n = 1;
int dB_k = 1;
int dA_m = 1;
int dA_n = 1;
int dA_k = 1;
T* dB_input_1_ptr = nullptr;
T* dB_input_2_ptr = nullptr;
T* dB_output_ptr = d_weight;
T* dA_input_1_ptr = nullptr;
T* dA_input_2_ptr = nullptr;
T* dA_output_ptr = d_input;
if (!transA_) {
// fw: gemm-nt
if (transB_) {
// bw: gemm-tn, dB = (dC)^t * A
dB_transA = CblasTrans;
dB_transB = CblasNoTrans;
dB_m = output_size_;
dB_n = input_size_;
dB_k = bsz_seq_;
// bw: gemm-nn, dA = dC * B
dA_transA = CblasNoTrans;
dA_transB = CblasNoTrans;
dA_m = bsz_seq_;
dA_n = input_size_;
dA_k = output_size_;
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, d_output,
input, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
} else { // fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA = CblasTrans;
dB_transB = CblasNoTrans;
dB_m = input_size_;
dB_n = output_size_;
dB_k = bsz_seq_;
// bw: gemm-nt, dA = dC * B^t
dA_transA = CblasNoTrans;
dA_transB = CblasTrans;
dA_m = bsz_seq_;
dA_n = input_size_;
dA_k = output_size_;
blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha, input,
d_output, beta, dB_output_ptr);
blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha, d_output,
weight, beta, dA_output_ptr);
}
} else if (transB_) {
PADDLE_THROW(platform::errors::InvalidArgument(
"AttnMatMul wrapper do not support (transA=T, transB=T)"
"parameters."));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"AttnMatMul wrapper do not support (transA=T, transB=N)"
"parameters."));
}
if (compute_bias_) {
LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
}
}
private:
const platform::CUDADeviceContext& dev_ctx_;
bool transA_;
bool transB_;
int bsz_seq_;
int output_size_;
int input_size_;
int compute_bias_;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AttnDropoutParam {
public:
AttnDropoutParam() {
is_test_ = false;
dropout_implementation_ = "downgrade_in_infer";
dropout_prob_ = 0.5;
is_upscale_in_train_ = false;
is_fix_seed_ = false;
seed_val_ = 0;
seed_ = nullptr;
}
AttnDropoutParam(bool is_test, const std::string dropout_implementation,
float dropout_prob, bool is_upscale_in_train,
bool is_fix_seed, int seed_val, const Tensor* seed) {
is_test_ = is_test;
dropout_implementation_ = dropout_implementation;
dropout_prob_ = dropout_prob;
is_upscale_in_train_ = is_upscale_in_train;
is_fix_seed_ = is_fix_seed;
seed_val_ = seed_val;
seed_ = seed;
}
bool is_test_;
std::string dropout_implementation_;
float dropout_prob_;
bool is_upscale_in_train_;
bool is_fix_seed_;
int seed_val_;
const Tensor* seed_;
};
template <typename T>
class FMHARef {
public:
FMHARef(const platform::CUDADeviceContext& dev_ctx, int64_t batch_size,
int64_t seq_len, int64_t num_head, int64_t head_dim,
AttnDropoutParam param)
: dev_ctx_(dev_ctx),
batch_size_(batch_size),
seq_len_(seq_len),
num_head_(num_head),
head_dim_(head_dim),
dropout_param_(param) {}
~FMHARef() {}
void ComputeForward(const Tensor& qkv_input_tensor,
const Tensor& src_mask_tensor,
Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
Tensor* dropout_mask_out_tensor,
Tensor* dropout_out_tensor, Tensor* qktv_out_tensor,
Tensor* fmha_out_tensor) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 1, 3, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
int ndims = 5;
std::vector<int> perm_1 = {2, 0, 3, 1, 4};
TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qkv_input_tensor, perm_1,
transpose_2_out_tensor);
T* qkv_data = transpose_2_out_tensor->data<T>();
T* qk_out_data = qk_out_tensor->data<T>();
T* qktv_out_data = qktv_out_tensor->data<T>();
T* softmax_out_data = softmax_out_tensor->data<T>();
T* dropout_out_data = dropout_out_tensor->data<T>();
T* fmha_out_data = fmha_out_tensor->data<T>();
int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
int k_size = q_size;
T* q_ptr = qkv_data;
T* k_ptr = q_ptr + q_size;
T* v_ptr = k_ptr + k_size;
// q*k^t, batched_gemm
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = seq_len_;
int gemm_k = head_dim_;
T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
stride_b);
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(&src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
int softmax_axis = -1;
if (&src_mask_tensor != nullptr) {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
softmax_axis, softmax_out_tensor);
} else {
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out_tensor, softmax_axis,
softmax_out_tensor);
}
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
alpha = static_cast<T>(1.0);
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
dev_ctx_, dropout_param_.is_test_,
static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_, dropout_param_.seed_val_,
static_cast<const Tensor&>(*softmax_out_tensor), dropout_param_.seed_,
dropout_mask_out_tensor, dropout_out_tensor);
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
dropout_out_data, v_ptr, beta, qktv_out_data,
gemm_batch_size, stride_a, stride_b);
} else {
// softmax_out * v, batched_gemm
// output shape: [batch_size, num_heads, seq_len, head_dim]
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
softmax_out_data, v_ptr, beta, qktv_out_data,
gemm_batch_size, stride_a, stride_b);
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std::vector<int> perm_3 = {0, 2, 1, 3};
ndims = 4;
TransposeGPUKernelDriver<T>(dev_ctx_, ndims, *qktv_out_tensor, perm_3,
fmha_out_tensor);
}
void ComputeBackward(
const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
Tensor* qktv_out_grad_tensor, Tensor* dropout_out_grad_tensor,
Tensor* softmax_out_grad_tensor, Tensor* src_mask_out_grad_tensor,
Tensor* qk_out_grad_tensor, Tensor* transpose_2_out_grad_tensor,
Tensor* src_mask_grad_tensor, Tensor* qkv_input_grad_tensor) {
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
int k_size = q_size;
int softmax_axis = -1;
T* qkv_grad_data = transpose_2_out_grad_tensor->data<T>();
T* q_grad_ptr = qkv_grad_data;
T* k_grad_ptr = q_grad_ptr + q_size;
T* v_grad_ptr = k_grad_ptr + k_size;
const T* qkv_data = transpose_2_out_tensor.data<T>();
const T* q_ptr = qkv_data;
const T* k_ptr = q_ptr + q_size;
const T* v_ptr = k_ptr + k_size;
const T* softmax_out_data = softmax_out_tensor.data<T>();
T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
const T* dropout_out_data = dropout_out_tensor.data<T>();
T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();
// transpose bw
int ndims = 4;
std::vector<int> perm_3 = {0, 2, 1, 3};
TransposeGPUKernelDriver<T>(dev_ctx_, ndims, fmha_out_grad_tensor, perm_3,
qktv_out_grad_tensor);
// recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
// qktv_out_data(out)
CBLAS_TRANSPOSE transA = CblasTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
int gemm_batch_size = batch_size_ * num_head_;
int gemm_m = seq_len_;
int gemm_n = head_dim_;
int gemm_k = seq_len_;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
int64_t stride_a = gemm_m * gemm_k;
int64_t stride_b = gemm_k * gemm_n;
// bw: dy = x^t * dout
if (dropout_param_.dropout_prob_) {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
dropout_out_data, qktv_out_grad_data, beta, v_grad_ptr,
gemm_batch_size, stride_a, stride_b);
} else {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
softmax_out_data, qktv_out_grad_data, beta, v_grad_ptr,
gemm_batch_size, stride_a, stride_b);
}
// bw: dx = dout * y^t
transA = CblasNoTrans;
transB = CblasTrans;
gemm_m = seq_len_;
gemm_n = seq_len_;
gemm_k = head_dim_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
if (dropout_param_.dropout_prob_) {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
qktv_out_grad_data, v_ptr, beta, dropout_out_grad_data,
gemm_batch_size, stride_a, stride_b);
} else {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
qktv_out_grad_data, v_ptr, beta, softmax_out_grad_data,
gemm_batch_size, stride_a, stride_b);
}
// dropout bw
if (dropout_param_.dropout_prob_) {
DropoutGradGPUKernelDriver<T>(
dev_ctx_, static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_,
static_cast<const Tensor&>(*dropout_out_grad_tensor),
dropout_mask_out_tensor, softmax_out_grad_tensor->numel(),
softmax_out_grad_tensor);
}
if (&src_mask_tensor != nullptr) {
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
*softmax_out_grad_tensor, softmax_axis,
src_mask_out_grad_tensor);
// recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out +
// src_mask
// Special case when dy is not needed and dx doesn't reduce
if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr &&
qk_out_tensor.dims() == src_mask_out_tensor.dims()) {
VLOG(4) << "Special case when dy is not needed and dx doesn't "
"reduce";
framework::TensorCopy(*src_mask_out_grad_tensor, dev_ctx_.GetPlace(),
dev_ctx_, qk_out_grad_tensor);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only used for the backward elementwise_add op when"
"dy is not needed and dx is not reduce"));
return;
}
} else {
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
*softmax_out_grad_tensor, softmax_axis,
qk_out_grad_tensor);
}
T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
alpha = static_cast<T>(1.0 / sqrt(head_dim_));
// recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
// bw: dy (seq_len * head_dim) = (dout)^t * x
transA = CblasTrans;
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
qk_out_grad_data, q_ptr, beta, k_grad_ptr, gemm_batch_size,
stride_a, stride_b);
// dx (seq_len * head_dim) = dout * y
transA = CblasNoTrans;
transB = CblasNoTrans;
gemm_m = seq_len_;
gemm_n = head_dim_;
gemm_k = seq_len_;
stride_a = gemm_m * gemm_k;
stride_b = gemm_k * gemm_n;
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha,
qk_out_grad_data, k_ptr, beta, q_grad_ptr, gemm_batch_size,
stride_a, stride_b);
// transpose bw
ndims = 5;
std::vector<int> perm_1 = {1, 3, 0, 2, 4};
TransposeGPUKernelDriver<T>(dev_ctx_, ndims, *transpose_2_out_grad_tensor,
perm_1, qkv_input_grad_tensor);
}
private:
const platform::CUDADeviceContext& dev_ctx_;
int64_t batch_size_;
int64_t seq_len_;
int64_t num_head_;
int64_t head_dim_;
AttnDropoutParam dropout_param_;
};
} // namespace operators
} // namespace paddle
......@@ -35,7 +35,6 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册