未验证 提交 d1a33bc7 编写于 作者: L Li Min 提交者: GitHub

Add feed_forward for fused attention op. (#34945)

Describe

Add feed_forward for fused attention op.
(1) Encapsulate matmul impl (forward and backward) used in attention op.
(2) Implement bias_add (forward and backward) used in attention op.
上级 fa6c59a4
......@@ -157,6 +157,7 @@ cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_co
if (WITH_GPU)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
nv_test(feed_forward_test SRCS feed_forward_test.cu DEPS elementwise_add_op matmul_op tensor generator)
elseif(WITH_ROCM)
hip_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator)
hip_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/attn_feed_forward.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
USE_OP(matmul);
USE_OP(elementwise_add);
// get paddle matmul op results as baseline
template <typename T>
void GetLinearOp(const std::vector<T> &x, const std::vector<T> &y,
const framework::DDim &x_dim, const framework::DDim &y_dim,
const platform::CUDADeviceContext &ctx, bool transpose_a,
bool transpose_b, float alpha, std::vector<T> *out) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
auto var_out = scope.Var("Out");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
tensor_x->Resize(x_dim);
tensor_y->Resize(y_dim);
tensor_out->Resize({x_dim[0], x_dim[1], y_dim[0]});
auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
auto z_ptr = tensor_out->mutable_data<T>(ctx.GetPlace());
auto size_x = static_cast<size_t>(framework::product(x_dim));
auto size_y = static_cast<size_t>(framework::product(y_dim));
auto size_z = x_dim[0] * x_dim[1] * y_dim[0];
cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);
framework::AttributeMap attrs;
attrs.insert({"transpose_X", transpose_a});
attrs.insert({"transpose_Y", transpose_b});
attrs.insert({"alpha", alpha});
auto op = framework::OpRegistry::CreateOp(
"matmul", {{"X", {"X"}}, {"Y", {"Y"}}}, {{"Out", {"Out"}}}, attrs);
op->Run(scope, ctx.GetPlace());
cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost);
ctx.Wait();
}
// get paddle elementwise_add op results as baseline
template <typename T>
void GetElementwiseAddOp(const std::vector<T> &x, const std::vector<T> &y,
const int bsz_seq, const int output_size,
const platform::CUDADeviceContext &ctx,
std::vector<T> *out) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
auto var_out = scope.Var("Out");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
tensor_x->Resize({bsz_seq, output_size});
tensor_y->Resize({output_size});
tensor_out->Resize({bsz_seq, output_size});
auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
auto z_ptr = tensor_out->mutable_data<T>(ctx.GetPlace());
auto size_x = bsz_seq * output_size;
auto size_y = output_size;
auto size_z = bsz_seq * output_size;
cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);
framework::AttributeMap attrs;
auto op = framework::OpRegistry::CreateOp("elementwise_add",
{{"X", {"X"}}, {"Y", {"Y"}}},
{{"Out", {"Out"}}}, attrs);
op->Run(scope, ctx.GetPlace());
cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost);
ctx.Wait();
}
// get paddle matmul_grad op results as baseline
template <typename T>
void GetLinearOpGrad(const std::vector<T> &x_vec, const std::vector<T> &y_vec,
const std::vector<T> &dout_vec,
const framework::DDim &x_dim, const framework::DDim &y_dim,
const framework::DDim &out_dim,
const platform::CUDADeviceContext &ctx, bool transpose_a,
bool transpose_b, float alpha, std::vector<T> *dinput_vec,
std::vector<T> *dweight_vec) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
auto var_dout = scope.Var("DOut");
auto tensor_dout = var_dout->GetMutable<framework::LoDTensor>();
tensor_x->Resize(x_dim);
tensor_y->Resize(y_dim);
tensor_dout->Resize(out_dim);
auto var_dx = scope.Var("DX");
auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
auto var_dy = scope.Var("DY");
auto tensor_dy = var_dy->GetMutable<framework::LoDTensor>();
tensor_dx->Resize(x_dim);
tensor_dy->Resize(y_dim);
auto x_ptr = tensor_x->mutable_data<T>(ctx.GetPlace());
auto y_ptr = tensor_y->mutable_data<T>(ctx.GetPlace());
auto dout_ptr = tensor_dout->mutable_data<T>(ctx.GetPlace());
auto dinput_ptr = tensor_dx->mutable_data<T>(ctx.GetPlace());
auto dweight_ptr = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto size_x = static_cast<size_t>(framework::product(x_dim));
auto size_y = static_cast<size_t>(framework::product(y_dim));
auto size_z = x_dim[0] * x_dim[1] * y_dim[0];
cudaMemcpy(x_ptr, x_vec.data(), size_x * sizeof(T), cudaMemcpyHostToDevice);
cudaMemcpy(y_ptr, y_vec.data(), size_y * sizeof(T), cudaMemcpyHostToDevice);
cudaMemcpy(dout_ptr, dout_vec.data(), size_z * sizeof(T),
cudaMemcpyHostToDevice);
bool use_mkldnn = false;
std::vector<int> fused_reshape_X = {};
std::vector<int> fused_reshape_Y = {};
std::vector<int> fused_reshape_Out = {};
std::vector<int> fused_transpose_X = {};
std::vector<int> fused_transpose_Y = {};
std::vector<int> fused_transpose_Out = {};
bool use_quantizer = false, force_fp32_output = false;
std::string mkldnn_data_type = "float32";
float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0;
framework::AttributeMap attrs;
attrs.insert({"transpose_X", transpose_a});
attrs.insert({"transpose_Y", transpose_b});
attrs.insert({"alpha", alpha});
attrs.insert({"use_mkldnn", use_mkldnn});
attrs.insert({"fused_reshape_X", fused_reshape_X});
attrs.insert({"fused_reshape_Y", fused_reshape_Y});
attrs.insert({"fused_reshape_Out", fused_reshape_Out});
attrs.insert({"fused_transpose_X", fused_transpose_X});
attrs.insert({"fused_transpose_Y", fused_transpose_Y});
attrs.insert({"fused_transpose_Out", fused_transpose_Out});
attrs.insert({"use_quantizer", use_quantizer});
attrs.insert({"mkldnn_data_type", mkldnn_data_type});
attrs.insert({"Scale_x", Scale_x});
attrs.insert({"Scale_y", Scale_y});
attrs.insert({"Scale_out", Scale_out});
attrs.insert({"force_fp32_output", force_fp32_output});
auto op = framework::OpRegistry::CreateOp(
"matmul_grad", {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
{{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs);
op->Run(scope, ctx.GetPlace());
cudaMemcpy(dinput_vec->data(), dinput_ptr, size_x * sizeof(T),
cudaMemcpyDeviceToHost);
cudaMemcpy(dweight_vec->data(), dweight_ptr, size_y * sizeof(T),
cudaMemcpyDeviceToHost);
ctx.Wait();
}
// get paddle elementwise_add_grad op results as baseline
template <typename T>
void GetElementwiseAddOpGrad(const std::vector<T> &dout_vec, const int bsz_seq,
const int output_size,
const platform::CUDADeviceContext &ctx,
std::vector<T> *dy_vec) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
auto var_dout = scope.Var("DOut");
auto tensor_dout = var_dout->GetMutable<framework::LoDTensor>();
tensor_x->Resize({bsz_seq, output_size});
tensor_y->Resize({output_size});
tensor_dout->Resize({bsz_seq, output_size});
auto var_dx = scope.Var("DX");
auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
auto var_dy = scope.Var("DY");
auto tensor_dy = var_dy->GetMutable<framework::LoDTensor>();
tensor_dx->Resize({bsz_seq, output_size});
tensor_dy->Resize({output_size});
auto dout_ptr = tensor_dout->mutable_data<T>(ctx.GetPlace());
auto tensor_dy_ptr = tensor_dy->mutable_data<T>(ctx.GetPlace());
auto size_z = static_cast<size_t>(bsz_seq * output_size);
cudaMemcpy(dout_ptr, dout_vec.data(), size_z * sizeof(T),
cudaMemcpyHostToDevice);
int axis = -1;
bool use_mkldnn = false, use_quantizer = false;
std::string mkldnn_data_type = "float32";
std::string x_data_format = "", y_data_format = "";
float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0;
framework::AttributeMap attrs;
attrs.insert({"axis", axis});
attrs.insert({"use_mkldnn", use_mkldnn});
attrs.insert({"x_data_format", x_data_format});
attrs.insert({"y_data_format", y_data_format});
attrs.insert({"use_quantizer", use_quantizer});
attrs.insert({"mkldnn_data_type", mkldnn_data_type});
attrs.insert({"Scale_x", Scale_x});
attrs.insert({"Scale_y", Scale_y});
attrs.insert({"Scale_out", Scale_out});
auto op = framework::OpRegistry::CreateOp(
"elementwise_add_grad",
{{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}},
{{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs);
op->Run(scope, ctx.GetPlace());
auto size_y = static_cast<size_t>(output_size);
cudaMemcpy(dy_vec->data(), tensor_dy_ptr, size_y * sizeof(T),
cudaMemcpyDeviceToHost);
ctx.Wait();
}
template <typename T>
class TestFeedForward {
public:
TestFeedForward() {
batch_size_ = 16;
seq_len_ = 128;
num_head_ = 16;
dim_head_ = 64;
dim_embed_ = 1024;
has_bias_ = false;
}
TestFeedForward(int batch_size, int seq_len, int num_head, int dim_head,
int dim_embed, bool has_bias) {
batch_size_ = batch_size;
seq_len_ = seq_len;
num_head_ = num_head;
dim_head_ = dim_head;
dim_embed_ = dim_embed;
has_bias_ = has_bias;
}
~TestFeedForward() { delete ctx_; }
void SetUp() {
bsz_seq_ = batch_size_ * seq_len_;
output_size_ = 3 * num_head_ * dim_head_;
input_size_ = dim_embed_;
ctx_ = new platform::CUDADeviceContext(place_);
size_src_ = bsz_seq_ * dim_embed_; // src: [bs, seq_len, em_dim]
size_weight_ = dim_embed_ * output_size_; // weight: [output_size, em_dim]
size_output_ =
bsz_seq_ * output_size_; // output: [bs, seq_len, output_size]
size_bias_ = output_size_;
base_out_vec_.resize(size_output_);
base_bias_out_vec_.resize(size_output_);
base_dinput_vec_.resize(size_src_);
base_dweight_vec_.resize(size_weight_);
base_dbias_vec_.resize(size_bias_);
src_vec_.resize(size_src_);
weight_vec_.resize(size_weight_);
bias_vec_.resize(size_bias_);
doutput_vec_.resize(size_output_);
std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0);
for (int i = 0; i < size_src_; i++) {
src_vec_[i] = static_cast<T>(dis(random));
}
for (int i = 0; i < size_weight_; i++) {
weight_vec_[i] = static_cast<T>(dis(random));
}
for (int i = 0; i < size_bias_; i++) {
bias_vec_[i] = static_cast<T>(dis(random));
}
for (int i = 0; i < size_output_; i++) {
doutput_vec_[i] = static_cast<T>(dis(random));
}
framework::TensorFromVector<T>(src_vec_, *ctx_, &src_);
src_.Resize({batch_size_, seq_len_, dim_embed_});
framework::TensorFromVector<T>(weight_vec_, *ctx_, &weight_);
weight_.Resize({output_size_, dim_embed_});
out_.Resize({batch_size_, seq_len_, output_size_});
out_.mutable_data<T>(place_);
if (has_bias_) {
framework::TensorFromVector<T>(bias_vec_, *ctx_, &bias_);
bias_.Resize({output_size_});
bias_out_.Resize({batch_size_, seq_len_, output_size_});
bias_out_.mutable_data<T>(place_);
}
framework::TensorFromVector<T>(doutput_vec_, *ctx_, &doutput_);
doutput_.Resize({batch_size_, seq_len_, output_size_});
dinput_.Resize({batch_size_, seq_len_, dim_embed_});
dinput_.mutable_data<T>(place_);
dweight_.Resize({output_size_, dim_embed_});
dweight_.mutable_data<T>(place_);
if (has_bias_) {
dbias_.Resize({output_size_});
dbias_.mutable_data<T>(place_);
}
}
void BaselineForward() {
bool transpose_a = false, transpose_b = true;
float alpha = 1;
GetLinearOp(src_vec_, weight_vec_, src_.dims(), weight_.dims(), *ctx_,
transpose_a, transpose_b, alpha, &base_out_vec_);
if (has_bias_) {
GetElementwiseAddOp(base_out_vec_, bias_vec_, bsz_seq_, output_size_,
*ctx_, &base_bias_out_vec_);
}
ctx_->Wait();
}
// get forward results of feedforward.
void FusedForward() {
T *p_weight = weight_.data<T>();
T *p_src = src_.data<T>();
T *p_output = out_.data<T>();
T *p_bias = nullptr;
T *p_bias_output = nullptr;
if (has_bias_) {
p_bias = bias_.data<T>();
p_bias_output = bias_out_.data<T>();
}
auto qkv_compute = paddle::operators::FeedForward<T>(
*ctx_, bsz_seq_, output_size_, input_size_, has_bias_);
qkv_compute.ComputeForward(p_weight, p_src, p_bias, p_output,
p_bias_output);
ctx_->Wait();
}
void BaselineBackward() {
bool transpose_a = false, transpose_b = true;
float alpha = 1;
GetLinearOpGrad(src_vec_, weight_vec_, doutput_vec_, src_.dims(),
weight_.dims(), out_.dims(), *ctx_, transpose_a,
transpose_b, alpha, &base_dinput_vec_, &base_dweight_vec_);
if (has_bias_) {
GetElementwiseAddOpGrad(doutput_vec_, bsz_seq_, output_size_, *ctx_,
&base_dbias_vec_);
}
ctx_->Wait();
}
// get backward results of feedforward.
void FusedBackward() {
T *p_weight = weight_.data<T>();
T *p_src = src_.data<T>();
T *p_doutput = doutput_.data<T>();
T *p_dinput = dinput_.data<T>();
T *p_dweight = dweight_.data<T>();
T *bias_ptr = nullptr;
if (has_bias_) {
bias_ptr = dbias_.data<T>();
}
auto qkv_compute = paddle::operators::FeedForward<T>(
*ctx_, bsz_seq_, output_size_, input_size_, has_bias_);
qkv_compute.ComputeBackward(p_src, p_weight, p_doutput, p_dinput, p_dweight,
bias_ptr);
ctx_->Wait();
}
void Run() {
SetUp();
BaselineForward();
FusedForward();
BaselineBackward();
FusedBackward();
}
// check forward correctness between baseline and results of feedforward.
void CheckOut(const T diff, bool is_relative_atol = false) {
std::vector<T> out(size_output_);
std::vector<T> bias_out(size_output_);
TensorToVector(out_, *ctx_, &out);
if (has_bias_) {
TensorToVector(bias_out_, *ctx_, &bias_out);
}
ctx_->Wait();
for (int i = 0; i < size_output_; i++) {
if (is_relative_atol) {
EXPECT_LT(std::abs((out[i] - base_out_vec_[i]) / base_out_vec_[i]),
diff);
} else {
EXPECT_LT(std::abs(out[i] - base_out_vec_[i]), diff);
}
if (has_bias_) {
if (is_relative_atol) {
EXPECT_LT(std::abs((bias_out[i] - base_bias_out_vec_[i]) /
base_bias_out_vec_[i]),
diff);
} else {
EXPECT_LT(std::abs(bias_out[i] - base_bias_out_vec_[i]), diff);
}
}
}
}
// check backward correctness between baseline and results of feedforward.
void CheckGrad(const T diff, bool is_relative_atol = false) {
std::vector<T> h_dinput(size_src_);
TensorToVector(dinput_, *ctx_, &h_dinput);
for (int i = 0; i < size_src_; i++) {
if (is_relative_atol) {
EXPECT_LT(
std::abs((h_dinput[i] - base_dinput_vec_[i]) / base_dinput_vec_[i]),
diff);
} else {
EXPECT_LT(std::abs(h_dinput[i] - base_dinput_vec_[i]), diff);
}
}
std::vector<T> h_dweight(size_weight_);
TensorToVector(dweight_, *ctx_, &h_dweight);
for (int i = 0; i < size_weight_; i++) {
if (is_relative_atol) {
EXPECT_LT(std::abs((h_dweight[i] - base_dweight_vec_[i]) /
base_dweight_vec_[i]),
diff);
} else {
EXPECT_LT(std::abs(h_dweight[i] - base_dweight_vec_[i]), diff);
}
}
if (has_bias_) {
std::vector<T> h_dbias(size_bias_);
TensorToVector(dbias_, *ctx_, &h_dbias);
for (int i = 0; i < size_bias_; i++) {
if (is_relative_atol) {
EXPECT_LT(
std::abs((h_dbias[i] - base_dbias_vec_[i]) / base_dbias_vec_[i]),
diff);
} else {
EXPECT_LT(std::abs(h_dbias[i] - base_dbias_vec_[i]), diff);
}
}
}
}
private:
int batch_size_, seq_len_, num_head_, dim_head_, dim_embed_;
int bsz_seq_, output_size_, input_size_;
bool has_bias_;
int size_src_, size_weight_, size_bias_, size_output_;
framework::Tensor src_, weight_, bias_, out_, bias_out_;
framework::Tensor dinput_, dweight_, dbias_, doutput_;
std::vector<T> src_vec_, weight_vec_, bias_vec_, out_vec_, bias_out_vec_;
std::vector<T> dinput_vec_, dweight_vec_, dbias_vec_, doutput_vec_;
// results of baseline.
std::vector<T> base_out_vec_, base_bias_out_vec_;
std::vector<T> base_dinput_vec_, base_dweight_vec_, base_dbias_vec_;
platform::CUDAPlace place_;
platform::CUDADeviceContext *ctx_;
};
// test for fp32, fp16, fp32+bias and fp16+bias
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32) {
int batch_size = 16;
int seq_len = 128;
int num_head = 16;
int dim_head = 64;
int dim_embed = 1024;
bool has_bias = false;
TestFeedForward<float> test(batch_size, seq_len, num_head, dim_head,
dim_embed, has_bias);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
}
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16) {
int batch_size = 16;
int seq_len = 128;
int num_head = 16;
int dim_head = 64;
int dim_embed = 1024;
bool has_bias = false;
TestFeedForward<paddle::platform::float16> test(
batch_size, seq_len, num_head, dim_head, dim_embed, has_bias);
test.Run();
test.CheckOut(static_cast<paddle::platform::float16>(1e-5));
test.CheckGrad(static_cast<paddle::platform::float16>(1e-5));
}
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32Bias) {
int batch_size = 16;
int seq_len = 128;
int num_head = 16;
int dim_head = 64;
int dim_embed = 1024;
bool has_bias = true;
TestFeedForward<float> test(batch_size, seq_len, num_head, dim_head,
dim_embed, has_bias);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3));
}
TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16Bias) {
int batch_size = 16;
int seq_len = 128;
int num_head = 16;
int dim_head = 64;
int dim_embed = 1024;
bool has_bias = true;
TestFeedForward<paddle::platform::float16> test(
batch_size, seq_len, num_head, dim_head, dim_embed, has_bias);
test.Run();
test.CheckOut(static_cast<paddle::platform::float16>(1e-2));
test.CheckGrad(static_cast<paddle::platform::float16>(1e-2), true);
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/platform/fast_divmod.h"
namespace paddle {
namespace operators {
#define MAX_INPUT_NUM 2
namespace kps = paddle::operators::kernel_primitives;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using ReduceParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T>
struct CudaAddFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] + args[1];
}
};
template <typename InT, typename OutT, int ShapeSize, int VecSize,
int DATA_PER_THREAD, typename Functor>
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
InT arg0[VecSize * DATA_PER_THREAD];
InT arg1[VecSize * DATA_PER_THREAD];
OutT result[VecSize * DATA_PER_THREAD];
if (blockIdx.x < main_tid) {
num = blockDim.x * VecSize; // blockIdx.x < main_tid
}
// load in0
if (use_broadcast[0]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
arg0, in0, fix, configlists[0], numel, 1, 1);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
arg1, in1, fix, configlists[1], numel, 1, 1);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
}
// compute
kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, arg0, arg1, func);
// store
kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
}
template <typename T>
int GetVectorizedSizeImpl(const T* pointer) {
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec8 =
std::alignment_of<platform::CudaAlignedVector<T, 8>>::value; // NOLINT
constexpr int vec4 =
std::alignment_of<platform::CudaAlignedVector<T, 4>>::value; // NOLINT
constexpr int vec2 =
std::alignment_of<platform::CudaAlignedVector<T, 2>>::value; // NOLINT
if (address % vec8 == 0) {
// Note: this line can change from 4 to 8 if it can improve the performance.
return std::min(4, valid_vec_size);
} else if (address % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0) {
return std::min(2, valid_vec_size);
} else {
return 1;
}
}
// bias add forward impl for "[m, n] + [n] = [m, n]"
template <typename T>
void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
const T* in0, const T* in1, T* out) {
int in_vec_size =
std::min(GetVectorizedSizeImpl<T>(in0), GetVectorizedSizeImpl<T>(in1));
int out_vec_size = std::min(4, GetVectorizedSizeImpl<T>(out));
int vec_size = std::min(out_vec_size, in_vec_size);
int numel = m * n;
const int threads = 256;
const int data_per_thread = 1;
int blocks =
((numel + vec_size * data_per_thread - 1) / (vec_size * data_per_thread) +
threads - 1) /
threads;
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);
framework::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
use_broadcast[0] = false;
use_broadcast[1] = false;
if (m != 1) {
use_broadcast[1] = true;
}
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
auto func = CudaAddFunctor<T>();
auto stream = ctx.stream();
switch (vec_size) {
case 4: {
BroadcastKernelBinary<T, T, 2, 4,
data_per_thread><<<blocks, threads, 0, stream>>>(
in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
break;
}
case 2: {
BroadcastKernelBinary<T, T, 2, 2,
data_per_thread><<<blocks, threads, 0, stream>>>(
in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
break;
}
case 1: {
BroadcastKernelBinary<T, T, 2, 1,
data_per_thread><<<blocks, threads, 0, stream>>>(
in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid,
func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}
template <typename T, int BlockDim>
__global__ void LAUNCH_BOUNDS(BlockDim)
Compute1DColumnReduceKernel(const int reduce_num, const int left_num,
const T* in, T* out) {
typedef cub::BlockReduce<ReduceParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage mean_storage;
for (int i = blockIdx.x; i < left_num; i += gridDim.x) {
ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
for (int j = threadIdx.x; j < reduce_num; j += blockDim.x) {
const int index = j * left_num + i;
ReduceParamType<T> x_i = static_cast<ReduceParamType<T>>(in[index]);
x_sum += x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
if (threadIdx.x == 0) {
out[i] = static_cast<T>(x_sum);
}
}
}
template <typename T>
void Launch1DColumnReduce(gpuStream_t stream, const int max_threads,
const int reduce_num, const int left_num,
const T* d_out, T* d_bias) {
const int block = 256;
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min(left_num, max_blocks);
Compute1DColumnReduceKernel<T, block><<<grid, block, 0, stream>>>(
reduce_num, left_num, d_out, d_bias);
}
void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
const int left_num, int* blocking_size,
bool* should_reduce_again, dim3* block_dim,
dim3* grid_dim) {
block_dim->z = 1;
grid_dim->z = 1;
*should_reduce_again = false;
int num_block = (max_threads / left_num);
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
*blocking_size = detail::GetLastPow2(reduce_num / num_block);
if (*blocking_size <= 1) {
*blocking_size = detail::GetLastPow2(sqrt(reduce_num));
} else if (*blocking_size * 2 < reduce_num) {
*blocking_size *= 2;
}
*should_reduce_again = true;
block_dim->x = 32;
block_dim->y = 1;
grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x;
grid_dim->y = (reduce_num + *blocking_size - 1) / *blocking_size;
} else {
block_dim->x = 32;
*blocking_size = reduce_num;
grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x;
grid_dim->y = 1;
}
}
template <typename T>
__global__ void BiasAddBwSinglePassKernel(const T* in, int reduce_num,
int left_num, T* out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
if (idx < left_num) {
for (int iy = 0; iy < reduce_num; iy++) {
int id = iy * left_num + idx;
ReduceParamType<T> x_val = static_cast<ReduceParamType<T>>(in[id]);
x_sum += x_val;
}
out[idx] = static_cast<T>(x_sum);
}
}
template <typename T>
__global__ void BiasAddBw2DReduceKernel(const T* x, int reduce_num,
int left_num, int workload_per_thread,
ReduceParamType<T>* temp_x_sum) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int idy = blockIdx.y * workload_per_thread;
T x_val;
ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
if (idx < left_num) {
int loop = reduce_num - idy;
loop = loop > workload_per_thread ? workload_per_thread : loop;
for (int iy = 0; iy < loop; iy++) {
int id = (idy + iy) * left_num + idx;
ReduceParamType<T> x_val = static_cast<ReduceParamType<T>>(x[id]);
x_sum += x_val;
}
temp_x_sum[idx + blockIdx.y * left_num] = x_sum;
}
}
template <typename T>
__global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
int workload_per_thread, int left_num,
T* out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
ReduceParamType<T> x_sum = static_cast<ReduceParamType<T>>(0);
if (idx < left_num) {
for (int iy = 0; iy < workload_per_thread; iy++) {
int id = iy * left_num + idx;
x_sum += temp_sum[id];
}
out[idx] = static_cast<T>(x_sum);
}
}
template <typename T>
void Launch2DColumnReduce(gpuStream_t stream, const int max_threads,
const int reduce_num, const int left_num,
const T* d_out, T* d_bias) {
dim3 block;
dim3 grid;
bool should_reduce_again = false;
int blocking_size = 1;
SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size,
&should_reduce_again, &block, &grid);
if (!should_reduce_again) {
BiasAddBwSinglePassKernel<T><<<grid, block, 0, stream>>>(d_out, reduce_num,
left_num, d_bias);
} else {
framework::Tensor tmp_sum;
tmp_sum.mutable_data<ReduceParamType<T>>(
framework::make_ddim({static_cast<int64_t>(
left_num * grid.y * sizeof(ReduceParamType<T>))}),
paddle::platform::CUDAPlace());
BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
d_out, reduce_num, left_num, blocking_size,
tmp_sum.template data<ReduceParamType<T>>());
BiasAddBw1DReduceKernel<T><<<grid.x, block.x, 0, stream>>>(
tmp_sum.template data<ReduceParamType<T>>(), grid.y, left_num, d_bias);
}
}
// bias add backward impl whose pattern are column-reduce with d_out[m, n] as
// input
// and d_bias[n] as output.
template <typename T>
void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m,
int n, const T* d_out, T* d_bias) {
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
int reduce_num = m;
int left_num = n;
bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
(left_num > REDUCE_SPLIT_BOUNDARY);
if (!is_large_enough) {
Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
d_out, d_bias);
} else {
Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
d_out, d_bias);
}
}
#undef MAX_INPUT_NUM
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
class FeedForward {
public:
FeedForward(const platform::CUDADeviceContext& dev_ctx, int bsz_seq,
int output_size, int input_size, bool compute_bias)
: dev_ctx_(dev_ctx),
bsz_seq_(bsz_seq),
output_size_(output_size),
input_size_(input_size),
compute_bias_(compute_bias) {}
~FeedForward() {}
void ComputeForward(const T* weight_data, const T* input_data,
const T* bias_data, T* output_data, T* bias_out_data) {
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// To convert to col-major expression, transa<->transb, A<->B,m<->n.
// column-major: gemm-tn.
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
// column-major: (m,n,k) = output_size,bsz_seq,input_size (weight*input=out)
// here: (m,n,k) = bsz_seq,output_size,input_size (input*weight=out)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
input_data, weight_data, beta, output_data);
if (compute_bias_) {
LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data,
bias_data, bias_out_data);
}
}
void ComputeBackward(T* input, T* weight, T* d_output, T* d_input,
T* d_weight, T* d_bias) {
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
// column-major: gemm-nt, get d_weight.
CBLAS_TRANSPOSE transA = CblasTrans;
CBLAS_TRANSPOSE transB = CblasNoTrans;
// column-major: (m,n,k): input_size,output_size,bsz (input*dout=dweight)
// here: (m,n,k): output_size,input_size,bsz (dout*input=dweight)
blas.GEMM(transA, transB, output_size_, input_size_, bsz_seq_, alpha,
d_output, input, beta, d_weight);
// column-major: gemm-nn: get d_input.
transA = CblasNoTrans;
// column-major: (m,n,k): input_size,bsz,output_size (weight*dout=dinput)
// here: (m, n, k): bsz, input_size, output_size, (dout*weight=dinput)
blas.GEMM(transA, transB, bsz_seq_, input_size_, output_size_, alpha,
d_output, weight, beta, d_input);
if (compute_bias_) {
LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias);
}
}
private:
const platform::CUDADeviceContext& dev_ctx_;
int bsz_seq_, output_size_, input_size_;
bool compute_bias_;
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册