From d1a33bc7b59033f2a08711356eb9e8a2e3642ae3 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Thu, 26 Aug 2021 16:41:54 +0800 Subject: [PATCH] Add feed_forward for fused attention op. (#34945) Describe Add feed_forward for fused attention op. (1) Encapsulate matmul impl (forward and backward) used in attention op. (2) Implement bias_add (forward and backward) used in attention op. --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/feed_forward_test.cu | 550 ++++++++++++++++++ .../fluid/operators/fused/attn_bias_add.cu.h | 351 +++++++++++ .../fluid/operators/fused/attn_feed_forward.h | 91 +++ 4 files changed, 993 insertions(+) create mode 100644 paddle/fluid/operators/feed_forward_test.cu create mode 100644 paddle/fluid/operators/fused/attn_bias_add.cu.h create mode 100644 paddle/fluid/operators/fused/attn_feed_forward.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ff232b7ea5..4a82f558ff 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -157,6 +157,7 @@ cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_co if (WITH_GPU) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3) + nv_test(feed_forward_test SRCS feed_forward_test.cu DEPS elementwise_add_op matmul_op tensor generator) elseif(WITH_ROCM) hip_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) hip_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3) diff --git a/paddle/fluid/operators/feed_forward_test.cu b/paddle/fluid/operators/feed_forward_test.cu new file mode 100644 index 0000000000..fe631500a3 --- /dev/null +++ b/paddle/fluid/operators/feed_forward_test.cu @@ -0,0 +1,550 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/attn_feed_forward.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/float16.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; + +USE_OP(matmul); +USE_OP(elementwise_add); + +// get paddle matmul op results as baseline +template +void GetLinearOp(const std::vector &x, const std::vector &y, + const framework::DDim &x_dim, const framework::DDim &y_dim, + const platform::CUDADeviceContext &ctx, bool transpose_a, + bool transpose_b, float alpha, std::vector *out) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + auto var_y = scope.Var("Y"); + auto tensor_y = var_y->GetMutable(); + auto var_out = scope.Var("Out"); + auto tensor_out = var_out->GetMutable(); + + tensor_x->Resize(x_dim); + tensor_y->Resize(y_dim); + tensor_out->Resize({x_dim[0], x_dim[1], y_dim[0]}); + + auto x_ptr = tensor_x->mutable_data(ctx.GetPlace()); + auto y_ptr = tensor_y->mutable_data(ctx.GetPlace()); + auto z_ptr = tensor_out->mutable_data(ctx.GetPlace()); + auto size_x = static_cast(framework::product(x_dim)); + auto size_y = static_cast(framework::product(y_dim)); + auto size_z = x_dim[0] * x_dim[1] * y_dim[0]; + cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice); + + framework::AttributeMap attrs; + attrs.insert({"transpose_X", transpose_a}); + attrs.insert({"transpose_Y", transpose_b}); + attrs.insert({"alpha", alpha}); + + auto op = framework::OpRegistry::CreateOp( + "matmul", {{"X", {"X"}}, {"Y", {"Y"}}}, {{"Out", {"Out"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost); + ctx.Wait(); +} + +// get paddle elementwise_add op results as baseline +template +void GetElementwiseAddOp(const std::vector &x, const std::vector &y, + const int bsz_seq, const int output_size, + const platform::CUDADeviceContext &ctx, + std::vector *out) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + auto var_y = scope.Var("Y"); + auto tensor_y = var_y->GetMutable(); + auto var_out = scope.Var("Out"); + auto tensor_out = var_out->GetMutable(); + + tensor_x->Resize({bsz_seq, output_size}); + tensor_y->Resize({output_size}); + tensor_out->Resize({bsz_seq, output_size}); + + auto x_ptr = tensor_x->mutable_data(ctx.GetPlace()); + auto y_ptr = tensor_y->mutable_data(ctx.GetPlace()); + auto z_ptr = tensor_out->mutable_data(ctx.GetPlace()); + auto size_x = bsz_seq * output_size; + auto size_y = output_size; + auto size_z = bsz_seq * output_size; + cudaMemcpy(x_ptr, x.data(), size_x * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(y_ptr, y.data(), size_y * sizeof(T), cudaMemcpyHostToDevice); + + framework::AttributeMap attrs; + auto op = framework::OpRegistry::CreateOp("elementwise_add", + {{"X", {"X"}}, {"Y", {"Y"}}}, + {{"Out", {"Out"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + cudaMemcpy(out->data(), z_ptr, size_z * sizeof(T), cudaMemcpyDeviceToHost); + ctx.Wait(); +} + +// get paddle matmul_grad op results as baseline +template +void GetLinearOpGrad(const std::vector &x_vec, const std::vector &y_vec, + const std::vector &dout_vec, + const framework::DDim &x_dim, const framework::DDim &y_dim, + const framework::DDim &out_dim, + const platform::CUDADeviceContext &ctx, bool transpose_a, + bool transpose_b, float alpha, std::vector *dinput_vec, + std::vector *dweight_vec) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + auto var_y = scope.Var("Y"); + auto tensor_y = var_y->GetMutable(); + auto var_dout = scope.Var("DOut"); + auto tensor_dout = var_dout->GetMutable(); + tensor_x->Resize(x_dim); + tensor_y->Resize(y_dim); + tensor_dout->Resize(out_dim); + + auto var_dx = scope.Var("DX"); + auto tensor_dx = var_dx->GetMutable(); + auto var_dy = scope.Var("DY"); + auto tensor_dy = var_dy->GetMutable(); + tensor_dx->Resize(x_dim); + tensor_dy->Resize(y_dim); + + auto x_ptr = tensor_x->mutable_data(ctx.GetPlace()); + auto y_ptr = tensor_y->mutable_data(ctx.GetPlace()); + auto dout_ptr = tensor_dout->mutable_data(ctx.GetPlace()); + auto dinput_ptr = tensor_dx->mutable_data(ctx.GetPlace()); + auto dweight_ptr = tensor_dy->mutable_data(ctx.GetPlace()); + + auto size_x = static_cast(framework::product(x_dim)); + auto size_y = static_cast(framework::product(y_dim)); + auto size_z = x_dim[0] * x_dim[1] * y_dim[0]; + cudaMemcpy(x_ptr, x_vec.data(), size_x * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(y_ptr, y_vec.data(), size_y * sizeof(T), cudaMemcpyHostToDevice); + cudaMemcpy(dout_ptr, dout_vec.data(), size_z * sizeof(T), + cudaMemcpyHostToDevice); + + bool use_mkldnn = false; + std::vector fused_reshape_X = {}; + std::vector fused_reshape_Y = {}; + std::vector fused_reshape_Out = {}; + std::vector fused_transpose_X = {}; + std::vector fused_transpose_Y = {}; + std::vector fused_transpose_Out = {}; + bool use_quantizer = false, force_fp32_output = false; + std::string mkldnn_data_type = "float32"; + float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0; + + framework::AttributeMap attrs; + attrs.insert({"transpose_X", transpose_a}); + attrs.insert({"transpose_Y", transpose_b}); + attrs.insert({"alpha", alpha}); + attrs.insert({"use_mkldnn", use_mkldnn}); + attrs.insert({"fused_reshape_X", fused_reshape_X}); + attrs.insert({"fused_reshape_Y", fused_reshape_Y}); + attrs.insert({"fused_reshape_Out", fused_reshape_Out}); + attrs.insert({"fused_transpose_X", fused_transpose_X}); + attrs.insert({"fused_transpose_Y", fused_transpose_Y}); + attrs.insert({"fused_transpose_Out", fused_transpose_Out}); + attrs.insert({"use_quantizer", use_quantizer}); + attrs.insert({"mkldnn_data_type", mkldnn_data_type}); + attrs.insert({"Scale_x", Scale_x}); + attrs.insert({"Scale_y", Scale_y}); + attrs.insert({"Scale_out", Scale_out}); + attrs.insert({"force_fp32_output", force_fp32_output}); + + auto op = framework::OpRegistry::CreateOp( + "matmul_grad", {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}}, + {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + cudaMemcpy(dinput_vec->data(), dinput_ptr, size_x * sizeof(T), + cudaMemcpyDeviceToHost); + cudaMemcpy(dweight_vec->data(), dweight_ptr, size_y * sizeof(T), + cudaMemcpyDeviceToHost); + ctx.Wait(); +} + +// get paddle elementwise_add_grad op results as baseline +template +void GetElementwiseAddOpGrad(const std::vector &dout_vec, const int bsz_seq, + const int output_size, + const platform::CUDADeviceContext &ctx, + std::vector *dy_vec) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + auto var_y = scope.Var("Y"); + auto tensor_y = var_y->GetMutable(); + auto var_dout = scope.Var("DOut"); + auto tensor_dout = var_dout->GetMutable(); + tensor_x->Resize({bsz_seq, output_size}); + tensor_y->Resize({output_size}); + tensor_dout->Resize({bsz_seq, output_size}); + + auto var_dx = scope.Var("DX"); + auto tensor_dx = var_dx->GetMutable(); + auto var_dy = scope.Var("DY"); + auto tensor_dy = var_dy->GetMutable(); + tensor_dx->Resize({bsz_seq, output_size}); + tensor_dy->Resize({output_size}); + + auto dout_ptr = tensor_dout->mutable_data(ctx.GetPlace()); + auto tensor_dy_ptr = tensor_dy->mutable_data(ctx.GetPlace()); + auto size_z = static_cast(bsz_seq * output_size); + cudaMemcpy(dout_ptr, dout_vec.data(), size_z * sizeof(T), + cudaMemcpyHostToDevice); + + int axis = -1; + bool use_mkldnn = false, use_quantizer = false; + std::string mkldnn_data_type = "float32"; + std::string x_data_format = "", y_data_format = ""; + float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0; + + framework::AttributeMap attrs; + attrs.insert({"axis", axis}); + attrs.insert({"use_mkldnn", use_mkldnn}); + attrs.insert({"x_data_format", x_data_format}); + attrs.insert({"y_data_format", y_data_format}); + attrs.insert({"use_quantizer", use_quantizer}); + attrs.insert({"mkldnn_data_type", mkldnn_data_type}); + attrs.insert({"Scale_x", Scale_x}); + attrs.insert({"Scale_y", Scale_y}); + attrs.insert({"Scale_out", Scale_out}); + + auto op = framework::OpRegistry::CreateOp( + "elementwise_add_grad", + {{"Out@GRAD", {"DOut"}}, {"X", {"X"}}, {"Y", {"Y"}}}, + {{"X@GRAD", {"DX"}}, {"Y@GRAD", {"DY"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + auto size_y = static_cast(output_size); + cudaMemcpy(dy_vec->data(), tensor_dy_ptr, size_y * sizeof(T), + cudaMemcpyDeviceToHost); + ctx.Wait(); +} + +template +class TestFeedForward { + public: + TestFeedForward() { + batch_size_ = 16; + seq_len_ = 128; + num_head_ = 16; + dim_head_ = 64; + dim_embed_ = 1024; + has_bias_ = false; + } + + TestFeedForward(int batch_size, int seq_len, int num_head, int dim_head, + int dim_embed, bool has_bias) { + batch_size_ = batch_size; + seq_len_ = seq_len; + num_head_ = num_head; + dim_head_ = dim_head; + dim_embed_ = dim_embed; + has_bias_ = has_bias; + } + + ~TestFeedForward() { delete ctx_; } + + void SetUp() { + bsz_seq_ = batch_size_ * seq_len_; + output_size_ = 3 * num_head_ * dim_head_; + input_size_ = dim_embed_; + ctx_ = new platform::CUDADeviceContext(place_); + + size_src_ = bsz_seq_ * dim_embed_; // src: [bs, seq_len, em_dim] + size_weight_ = dim_embed_ * output_size_; // weight: [output_size, em_dim] + size_output_ = + bsz_seq_ * output_size_; // output: [bs, seq_len, output_size] + size_bias_ = output_size_; + + base_out_vec_.resize(size_output_); + base_bias_out_vec_.resize(size_output_); + base_dinput_vec_.resize(size_src_); + base_dweight_vec_.resize(size_weight_); + base_dbias_vec_.resize(size_bias_); + + src_vec_.resize(size_src_); + weight_vec_.resize(size_weight_); + bias_vec_.resize(size_bias_); + doutput_vec_.resize(size_output_); + + std::default_random_engine random(time(NULL)); + std::uniform_real_distribution dis(0.0, 1.0); + for (int i = 0; i < size_src_; i++) { + src_vec_[i] = static_cast(dis(random)); + } + for (int i = 0; i < size_weight_; i++) { + weight_vec_[i] = static_cast(dis(random)); + } + for (int i = 0; i < size_bias_; i++) { + bias_vec_[i] = static_cast(dis(random)); + } + for (int i = 0; i < size_output_; i++) { + doutput_vec_[i] = static_cast(dis(random)); + } + + framework::TensorFromVector(src_vec_, *ctx_, &src_); + src_.Resize({batch_size_, seq_len_, dim_embed_}); + framework::TensorFromVector(weight_vec_, *ctx_, &weight_); + weight_.Resize({output_size_, dim_embed_}); + out_.Resize({batch_size_, seq_len_, output_size_}); + out_.mutable_data(place_); + if (has_bias_) { + framework::TensorFromVector(bias_vec_, *ctx_, &bias_); + bias_.Resize({output_size_}); + bias_out_.Resize({batch_size_, seq_len_, output_size_}); + bias_out_.mutable_data(place_); + } + framework::TensorFromVector(doutput_vec_, *ctx_, &doutput_); + doutput_.Resize({batch_size_, seq_len_, output_size_}); + + dinput_.Resize({batch_size_, seq_len_, dim_embed_}); + dinput_.mutable_data(place_); + dweight_.Resize({output_size_, dim_embed_}); + dweight_.mutable_data(place_); + if (has_bias_) { + dbias_.Resize({output_size_}); + dbias_.mutable_data(place_); + } + } + + void BaselineForward() { + bool transpose_a = false, transpose_b = true; + float alpha = 1; + GetLinearOp(src_vec_, weight_vec_, src_.dims(), weight_.dims(), *ctx_, + transpose_a, transpose_b, alpha, &base_out_vec_); + if (has_bias_) { + GetElementwiseAddOp(base_out_vec_, bias_vec_, bsz_seq_, output_size_, + *ctx_, &base_bias_out_vec_); + } + ctx_->Wait(); + } + + // get forward results of feedforward. + void FusedForward() { + T *p_weight = weight_.data(); + T *p_src = src_.data(); + T *p_output = out_.data(); + + T *p_bias = nullptr; + T *p_bias_output = nullptr; + if (has_bias_) { + p_bias = bias_.data(); + p_bias_output = bias_out_.data(); + } + auto qkv_compute = paddle::operators::FeedForward( + *ctx_, bsz_seq_, output_size_, input_size_, has_bias_); + qkv_compute.ComputeForward(p_weight, p_src, p_bias, p_output, + p_bias_output); + ctx_->Wait(); + } + + void BaselineBackward() { + bool transpose_a = false, transpose_b = true; + float alpha = 1; + + GetLinearOpGrad(src_vec_, weight_vec_, doutput_vec_, src_.dims(), + weight_.dims(), out_.dims(), *ctx_, transpose_a, + transpose_b, alpha, &base_dinput_vec_, &base_dweight_vec_); + if (has_bias_) { + GetElementwiseAddOpGrad(doutput_vec_, bsz_seq_, output_size_, *ctx_, + &base_dbias_vec_); + } + ctx_->Wait(); + } + + // get backward results of feedforward. + void FusedBackward() { + T *p_weight = weight_.data(); + T *p_src = src_.data(); + T *p_doutput = doutput_.data(); + T *p_dinput = dinput_.data(); + T *p_dweight = dweight_.data(); + + T *bias_ptr = nullptr; + if (has_bias_) { + bias_ptr = dbias_.data(); + } + auto qkv_compute = paddle::operators::FeedForward( + *ctx_, bsz_seq_, output_size_, input_size_, has_bias_); + qkv_compute.ComputeBackward(p_src, p_weight, p_doutput, p_dinput, p_dweight, + bias_ptr); + ctx_->Wait(); + } + + void Run() { + SetUp(); + BaselineForward(); + FusedForward(); + BaselineBackward(); + FusedBackward(); + } + + // check forward correctness between baseline and results of feedforward. + void CheckOut(const T diff, bool is_relative_atol = false) { + std::vector out(size_output_); + std::vector bias_out(size_output_); + TensorToVector(out_, *ctx_, &out); + if (has_bias_) { + TensorToVector(bias_out_, *ctx_, &bias_out); + } + ctx_->Wait(); + + for (int i = 0; i < size_output_; i++) { + if (is_relative_atol) { + EXPECT_LT(std::abs((out[i] - base_out_vec_[i]) / base_out_vec_[i]), + diff); + } else { + EXPECT_LT(std::abs(out[i] - base_out_vec_[i]), diff); + } + if (has_bias_) { + if (is_relative_atol) { + EXPECT_LT(std::abs((bias_out[i] - base_bias_out_vec_[i]) / + base_bias_out_vec_[i]), + diff); + } else { + EXPECT_LT(std::abs(bias_out[i] - base_bias_out_vec_[i]), diff); + } + } + } + } + + // check backward correctness between baseline and results of feedforward. + void CheckGrad(const T diff, bool is_relative_atol = false) { + std::vector h_dinput(size_src_); + TensorToVector(dinput_, *ctx_, &h_dinput); + for (int i = 0; i < size_src_; i++) { + if (is_relative_atol) { + EXPECT_LT( + std::abs((h_dinput[i] - base_dinput_vec_[i]) / base_dinput_vec_[i]), + diff); + } else { + EXPECT_LT(std::abs(h_dinput[i] - base_dinput_vec_[i]), diff); + } + } + std::vector h_dweight(size_weight_); + TensorToVector(dweight_, *ctx_, &h_dweight); + for (int i = 0; i < size_weight_; i++) { + if (is_relative_atol) { + EXPECT_LT(std::abs((h_dweight[i] - base_dweight_vec_[i]) / + base_dweight_vec_[i]), + diff); + } else { + EXPECT_LT(std::abs(h_dweight[i] - base_dweight_vec_[i]), diff); + } + } + if (has_bias_) { + std::vector h_dbias(size_bias_); + TensorToVector(dbias_, *ctx_, &h_dbias); + for (int i = 0; i < size_bias_; i++) { + if (is_relative_atol) { + EXPECT_LT( + std::abs((h_dbias[i] - base_dbias_vec_[i]) / base_dbias_vec_[i]), + diff); + } else { + EXPECT_LT(std::abs(h_dbias[i] - base_dbias_vec_[i]), diff); + } + } + } + } + + private: + int batch_size_, seq_len_, num_head_, dim_head_, dim_embed_; + int bsz_seq_, output_size_, input_size_; + bool has_bias_; + int size_src_, size_weight_, size_bias_, size_output_; + + framework::Tensor src_, weight_, bias_, out_, bias_out_; + framework::Tensor dinput_, dweight_, dbias_, doutput_; + std::vector src_vec_, weight_vec_, bias_vec_, out_vec_, bias_out_vec_; + std::vector dinput_vec_, dweight_vec_, dbias_vec_, doutput_vec_; + + // results of baseline. + std::vector base_out_vec_, base_bias_out_vec_; + std::vector base_dinput_vec_, base_dweight_vec_, base_dbias_vec_; + + platform::CUDAPlace place_; + platform::CUDADeviceContext *ctx_; +}; + +// test for fp32, fp16, fp32+bias and fp16+bias +TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32) { + int batch_size = 16; + int seq_len = 128; + int num_head = 16; + int dim_head = 64; + int dim_embed = 1024; + bool has_bias = false; + TestFeedForward test(batch_size, seq_len, num_head, dim_head, + dim_embed, has_bias); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16) { + int batch_size = 16; + int seq_len = 128; + int num_head = 16; + int dim_head = 64; + int dim_embed = 1024; + bool has_bias = false; + TestFeedForward test( + batch_size, seq_len, num_head, dim_head, dim_embed, has_bias); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FeedForward, GPUFeedforwardBertLargeSizeFp32Bias) { + int batch_size = 16; + int seq_len = 128; + int num_head = 16; + int dim_head = 64; + int dim_embed = 1024; + bool has_bias = true; + TestFeedForward test(batch_size, seq_len, num_head, dim_head, + dim_embed, has_bias); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); +} + +TEST(FeedForward, GPUFeedforwardBertLargeSizeFp16Bias) { + int batch_size = 16; + int seq_len = 128; + int num_head = 16; + int dim_head = 64; + int dim_embed = 1024; + bool has_bias = true; + TestFeedForward test( + batch_size, seq_len, num_head, dim_head, dim_embed, has_bias); + test.Run(); + test.CheckOut(static_cast(1e-2)); + test.CheckGrad(static_cast(1e-2), true); +} diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h new file mode 100644 index 0000000000..2e98a7f332 --- /dev/null +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -0,0 +1,351 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#else +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +#ifdef __HIPCC__ +#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim) +#else +#define LAUNCH_BOUNDS(BlockDim) +#endif + +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" +#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" +#include "paddle/fluid/platform/fast_divmod.h" + +namespace paddle { +namespace operators { + +#define MAX_INPUT_NUM 2 + +namespace kps = paddle::operators::kernel_primitives; + +template +using CudnnDataType = platform::CudnnDataType; +template +using ReduceParamType = typename CudnnDataType::BatchNormParamType; + +template +struct CudaAddFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] + args[1]; + } +}; + +template +__global__ void BroadcastKernelBinary( + const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out, + framework::Array use_broadcast, uint32_t numel, + framework::Array, MAX_INPUT_NUM> + configlists, + int main_tid, int tail_tid, Functor func) { + int fix = blockIdx.x * blockDim.x * VecSize; + int num = tail_tid; + InT arg0[VecSize * DATA_PER_THREAD]; + InT arg1[VecSize * DATA_PER_THREAD]; + OutT result[VecSize * DATA_PER_THREAD]; + if (blockIdx.x < main_tid) { + num = blockDim.x * VecSize; // blockIdx.x < main_tid + } + + // load in0 + if (use_broadcast[0]) { + kernel_primitives::ReadDataBc( + arg0, in0, fix, configlists[0], numel, 1, 1); + } else { + kernel_primitives::ReadData(arg0, in0 + fix, num); + } + // load in1 + if (use_broadcast[1]) { + kernel_primitives::ReadDataBc( + arg1, in1, fix, configlists[1], numel, 1, 1); + } else { + kernel_primitives::ReadData(arg1, in1 + fix, num); + } + // compute + kernel_primitives::ElementwiseBinary( + result, arg0, arg1, func); + // store + kernel_primitives::WriteData(out + fix, result, num); +} + +template +int GetVectorizedSizeImpl(const T* pointer) { + constexpr int max_load_bits = 128; + int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); + uint64_t address = reinterpret_cast(pointer); + constexpr int vec8 = + std::alignment_of>::value; // NOLINT + constexpr int vec4 = + std::alignment_of>::value; // NOLINT + constexpr int vec2 = + std::alignment_of>::value; // NOLINT + if (address % vec8 == 0) { + // Note: this line can change from 4 to 8 if it can improve the performance. + return std::min(4, valid_vec_size); + } else if (address % vec4 == 0) { + return std::min(4, valid_vec_size); + } else if (address % vec2 == 0) { + return std::min(2, valid_vec_size); + } else { + return 1; + } +} + +// bias add forward impl for "[m, n] + [n] = [m, n]" +template +void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, + const T* in0, const T* in1, T* out) { + int in_vec_size = + std::min(GetVectorizedSizeImpl(in0), GetVectorizedSizeImpl(in1)); + int out_vec_size = std::min(4, GetVectorizedSizeImpl(out)); + int vec_size = std::min(out_vec_size, in_vec_size); + + int numel = m * n; + const int threads = 256; + const int data_per_thread = 1; + int blocks = + ((numel + vec_size * data_per_thread - 1) / (vec_size * data_per_thread) + + threads - 1) / + threads; + int main_tid = numel / (data_per_thread * vec_size * threads); + int tail_tid = numel % (data_per_thread * vec_size * threads); + + framework::Array, MAX_INPUT_NUM> configlists; + framework::Array use_broadcast; + + use_broadcast[0] = false; + use_broadcast[1] = false; + if (m != 1) { + use_broadcast[1] = true; + } + // Here, dims are transposed due to the logic in BroadcastConfig. + std::vector input1_dims = {n, 1}; + std::vector out_dims = {n, m}; + configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2); + + auto func = CudaAddFunctor(); + auto stream = ctx.stream(); + switch (vec_size) { + case 4: { + BroadcastKernelBinary<<>>( + in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid, + func); + break; + } + case 2: { + BroadcastKernelBinary<<>>( + in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid, + func); + break; + } + case 1: { + BroadcastKernelBinary<<>>( + in0, in1, out, use_broadcast, numel, configlists, main_tid, tail_tid, + func); + break; + } + default: { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +} + +template +__global__ void LAUNCH_BOUNDS(BlockDim) + Compute1DColumnReduceKernel(const int reduce_num, const int left_num, + const T* in, T* out) { + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage mean_storage; + + for (int i = blockIdx.x; i < left_num; i += gridDim.x) { + ReduceParamType x_sum = static_cast>(0); + for (int j = threadIdx.x; j < reduce_num; j += blockDim.x) { + const int index = j * left_num + i; + ReduceParamType x_i = static_cast>(in[index]); + x_sum += x_i; + } + x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum()); + if (threadIdx.x == 0) { + out[i] = static_cast(x_sum); + } + } +} + +template +void Launch1DColumnReduce(gpuStream_t stream, const int max_threads, + const int reduce_num, const int left_num, + const T* d_out, T* d_bias) { + const int block = 256; + const int max_blocks = std::max(max_threads / block, 1); + const int grid = std::min(left_num, max_blocks); + Compute1DColumnReduceKernel<<>>( + reduce_num, left_num, d_out, d_bias); +} + +void SetConfigForColumnReduce(const int max_threads, const int reduce_num, + const int left_num, int* blocking_size, + bool* should_reduce_again, dim3* block_dim, + dim3* grid_dim) { + block_dim->z = 1; + grid_dim->z = 1; + *should_reduce_again = false; + + int num_block = (max_threads / left_num); + if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) { + *blocking_size = detail::GetLastPow2(reduce_num / num_block); + if (*blocking_size <= 1) { + *blocking_size = detail::GetLastPow2(sqrt(reduce_num)); + } else if (*blocking_size * 2 < reduce_num) { + *blocking_size *= 2; + } + *should_reduce_again = true; + block_dim->x = 32; + block_dim->y = 1; + grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x; + grid_dim->y = (reduce_num + *blocking_size - 1) / *blocking_size; + } else { + block_dim->x = 32; + *blocking_size = reduce_num; + grid_dim->x = (left_num + block_dim->x - 1) / block_dim->x; + grid_dim->y = 1; + } +} + +template +__global__ void BiasAddBwSinglePassKernel(const T* in, int reduce_num, + int left_num, T* out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + ReduceParamType x_sum = static_cast>(0); + if (idx < left_num) { + for (int iy = 0; iy < reduce_num; iy++) { + int id = iy * left_num + idx; + ReduceParamType x_val = static_cast>(in[id]); + x_sum += x_val; + } + out[idx] = static_cast(x_sum); + } +} + +template +__global__ void BiasAddBw2DReduceKernel(const T* x, int reduce_num, + int left_num, int workload_per_thread, + ReduceParamType* temp_x_sum) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int idy = blockIdx.y * workload_per_thread; + + T x_val; + ReduceParamType x_sum = static_cast>(0); + if (idx < left_num) { + int loop = reduce_num - idy; + loop = loop > workload_per_thread ? workload_per_thread : loop; + for (int iy = 0; iy < loop; iy++) { + int id = (idy + iy) * left_num + idx; + ReduceParamType x_val = static_cast>(x[id]); + x_sum += x_val; + } + temp_x_sum[idx + blockIdx.y * left_num] = x_sum; + } +} + +template +__global__ void BiasAddBw1DReduceKernel(const ReduceParamType* temp_sum, + int workload_per_thread, int left_num, + T* out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + ReduceParamType x_sum = static_cast>(0); + if (idx < left_num) { + for (int iy = 0; iy < workload_per_thread; iy++) { + int id = iy * left_num + idx; + x_sum += temp_sum[id]; + } + out[idx] = static_cast(x_sum); + } +} + +template +void Launch2DColumnReduce(gpuStream_t stream, const int max_threads, + const int reduce_num, const int left_num, + const T* d_out, T* d_bias) { + dim3 block; + dim3 grid; + bool should_reduce_again = false; + int blocking_size = 1; + SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size, + &should_reduce_again, &block, &grid); + + if (!should_reduce_again) { + BiasAddBwSinglePassKernel<<>>(d_out, reduce_num, + left_num, d_bias); + } else { + framework::Tensor tmp_sum; + tmp_sum.mutable_data>( + framework::make_ddim({static_cast( + left_num * grid.y * sizeof(ReduceParamType))}), + paddle::platform::CUDAPlace()); + + BiasAddBw2DReduceKernel<<>>( + d_out, reduce_num, left_num, blocking_size, + tmp_sum.template data>()); + + BiasAddBw1DReduceKernel<<>>( + tmp_sum.template data>(), grid.y, left_num, d_bias); + } +} + +// bias add backward impl whose pattern are column-reduce with d_out[m, n] as +// input +// and d_bias[n] as output. +template +void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m, + int n, const T* d_out, T* d_bias) { + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + int reduce_num = m; + int left_num = n; + bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) || + (left_num > REDUCE_SPLIT_BOUNDARY); + if (!is_large_enough) { + Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, + d_out, d_bias); + } else { + Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, + d_out, d_bias); + } +} + +#undef MAX_INPUT_NUM + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/attn_feed_forward.h b/paddle/fluid/operators/fused/attn_feed_forward.h new file mode 100644 index 0000000000..e7eba2da63 --- /dev/null +++ b/paddle/fluid/operators/fused/attn_feed_forward.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/attn_bias_add.cu.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +class FeedForward { + public: + FeedForward(const platform::CUDADeviceContext& dev_ctx, int bsz_seq, + int output_size, int input_size, bool compute_bias) + : dev_ctx_(dev_ctx), + bsz_seq_(bsz_seq), + output_size_(output_size), + input_size_(input_size), + compute_bias_(compute_bias) {} + + ~FeedForward() {} + + void ComputeForward(const T* weight_data, const T* input_data, + const T* bias_data, T* output_data, T* bias_out_data) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // To convert to col-major expression, transa<->transb, A<->B,m<->n. + + // column-major: gemm-tn. + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // column-major: (m,n,k) = output_size,bsz_seq,input_size (weight*input=out) + // here: (m,n,k) = bsz_seq,output_size,input_size (input*weight=out) + auto blas = math::GetBlas(dev_ctx_); + blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha, + input_data, weight_data, beta, output_data); + if (compute_bias_) { + LaunchBiasAddFwKernel(dev_ctx_, bsz_seq_, output_size_, output_data, + bias_data, bias_out_data); + } + } + + void ComputeBackward(T* input, T* weight, T* d_output, T* d_input, + T* d_weight, T* d_bias) { + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + auto blas = math::GetBlas(dev_ctx_); + + // column-major: gemm-nt, get d_weight. + CBLAS_TRANSPOSE transA = CblasTrans; + CBLAS_TRANSPOSE transB = CblasNoTrans; + // column-major: (m,n,k): input_size,output_size,bsz (input*dout=dweight) + // here: (m,n,k): output_size,input_size,bsz (dout*input=dweight) + blas.GEMM(transA, transB, output_size_, input_size_, bsz_seq_, alpha, + d_output, input, beta, d_weight); + + // column-major: gemm-nn: get d_input. + transA = CblasNoTrans; + // column-major: (m,n,k): input_size,bsz,output_size (weight*dout=dinput) + // here: (m, n, k): bsz, input_size, output_size, (dout*weight=dinput) + blas.GEMM(transA, transB, bsz_seq_, input_size_, output_size_, alpha, + d_output, weight, beta, d_input); + if (compute_bias_) { + LaunchBiasAddBwKernel(dev_ctx_, bsz_seq_, output_size_, d_output, d_bias); + } + } + + private: + const platform::CUDADeviceContext& dev_ctx_; + int bsz_seq_, output_size_, input_size_; + bool compute_bias_; +}; + +} // namespace operators +} // namespace paddle -- GitLab