Unverified commit 8910bb4a, authored by Huang Jiyi, committed by GitHub

[Phi decouple] move layer_norm_kernel.cu.h to phi (#50506)

* move layer_norm_kernel.cu.h to phi

* fix bugs

* fix namespace

* fix bugs

* fix CI-Windows

* replace mutable_data

* fix bugs

* fix bugs
Parent c8aa6405
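The change is largely mechanical: the header paddle/fluid/operators/layer_norm_kernel.cu.h becomes paddle/phi/kernels/funcs/layer_norm_impl.cu.h, its contents move from namespace paddle::operators to phi::funcs, call sites qualify the layer-norm helpers accordingly, error reporting switches from platform::errors to phi::errors, and temporary tensors are allocated through the device context instead of mutable_data. A minimal caller-side sketch of the renaming pattern (the wrapper function below is illustrative only; the include and the phi::funcs identifiers are taken from the hunks that follow):

// Old pattern, removed by this commit:
//   #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
//   using U = LayerNormParamType<T>;                 // resolved in paddle::operators
//   switch (GetDesiredBlockDim(feature_size)) { ... }

// New pattern:
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

template <typename T>
int PickLayerNormBlockDim(int64_t feature_size) {  // illustrative helper, not in the diff
  using U = phi::funcs::LayerNormParamType<T>;     // e.g. float when T is half
  (void)sizeof(U);                                 // U is what the kernels accumulate in
  return phi::funcs::GetDesiredBlockDim(feature_size);
}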
@@ -24,9 +24,9 @@
 #include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h"
 #include "paddle/fluid/operators/fused/fused_dropout_common.h"
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"

 namespace paddle {
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #pragma once

-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 namespace paddle {
 namespace operators {
@@ -35,11 +35,11 @@ class AttnLayerNorm {
   ~AttnLayerNorm() {}

   void ComputeForward(const InType* x_data,
-                      const LayerNormParamType<T>* scale_data,
-                      const LayerNormParamType<T>* bias_data,
+                      const phi::funcs::LayerNormParamType<T>* scale_data,
+                      const phi::funcs::LayerNormParamType<T>* bias_data,
                       OutType* y_data,
-                      LayerNormParamType<T>* mean_data,
-                      LayerNormParamType<T>* var_data,
+                      phi::funcs::LayerNormParamType<T>* mean_data,
+                      phi::funcs::LayerNormParamType<T>* var_data,
                       const float* dequant_out_scale_data = nullptr,
                       const int quant_out_scale_offset = 0,
                       const float quant_in_scale = 1.0,
@@ -48,14 +48,14 @@ class AttnLayerNorm {
                       const float quant_min_bound = -127.0) {
     auto stream = dev_ctx_.stream();

-    switch (GetDesiredBlockDim(feature_size_)) {
+    switch (phi::funcs::GetDesiredBlockDim(feature_size_)) {
       FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T,
-                           LayerNormParamType<T>,
-                           kBlockDim,
-                           false,
-                           InType,
-                           OutType>
+          phi::funcs::LayerNormForward<T,
+                                       phi::funcs::LayerNormParamType<T>,
+                                       kBlockDim,
+                                       false,
+                                       InType,
+                                       OutType>
           <<<batch_size_, kBlockDim, 0, stream>>>(x_data,
                                                   scale_data,
                                                   bias_data,
@@ -71,32 +71,33 @@ class AttnLayerNorm {
                                                   quant_max_bound,
                                                   quant_min_bound));
       default:
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Feature_size must be larger than 1"));
+        PADDLE_THROW(
+            phi::errors::InvalidArgument("Feature_size must be larger than 1"));
         break;
     }
   }

   void ComputeBackward(const T* x_data,
                        const T* d_y_data,
-                       const LayerNormParamType<T>* scale_data,
-                       const LayerNormParamType<T>* mean_data,
-                       const LayerNormParamType<T>* var_data,
+                       const phi::funcs::LayerNormParamType<T>* scale_data,
+                       const phi::funcs::LayerNormParamType<T>* mean_data,
+                       const phi::funcs::LayerNormParamType<T>* var_data,
                        T* d_x_data,
-                       LayerNormParamType<T>* d_scale_data,
-                       LayerNormParamType<T>* d_bias_data) {
-    LayerNormBackward<T, LayerNormParamType<T>>(x_data,
-                                                d_y_data,
-                                                scale_data,
-                                                mean_data,
-                                                var_data,
-                                                d_x_data,
-                                                d_scale_data,
-                                                d_bias_data,
-                                                epsilon_,
-                                                batch_size_,
-                                                feature_size_,
-                                                dev_ctx_);
+                       phi::funcs::LayerNormParamType<T>* d_scale_data,
+                       phi::funcs::LayerNormParamType<T>* d_bias_data) {
+    phi::funcs::LayerNormBackward<T, phi::funcs::LayerNormParamType<T>>(
+        x_data,
+        d_y_data,
+        scale_data,
+        mean_data,
+        var_data,
+        d_x_data,
+        d_scale_data,
+        d_bias_data,
+        epsilon_,
+        batch_size_,
+        feature_size_,
+        dev_ctx_);
   }

  private:
......
@@ -26,7 +26,7 @@ namespace operators {
 template <typename T>
 struct GeluFunctor {
   inline __host__ __device__ T operator()(const T x) const {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const U casted_x = static_cast<U>(x);
     const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
     const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
@@ -47,7 +47,7 @@ struct FastGeluFunctor {
 template <typename T>
 struct GeluGradFunctor {
   inline __host__ __device__ T UseOut(const T x) const {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     auto casted_x = static_cast<U>(x);
     auto first =
......
@@ -21,13 +21,13 @@ limitations under the License. */
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 namespace paddle {
 namespace operators {
@@ -138,7 +138,7 @@ inline __device__ void CalculateDBias(const T *tmp_sum,
   int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
   // reduce 32 to 1
   for (int i = 0; i < reduce_num_pre_thread; i++) {
-    sum[i] = WarpReduceSum(sum[i]);
+    sum[i] = phi::funcs::WarpReduceSum(sum[i]);
   }
   // save sum to dbias
......
@@ -418,18 +418,18 @@ class FusedDropoutLayerNormHelper
                        LayerNormParamType<T>* d_scale,
                        LayerNormParamType<T>* d_bias) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U>(src,
-                            dout,
-                            gamma,
-                            mean,
-                            variance,
-                            d_src,
-                            d_scale,
-                            d_bias,
-                            epsilon_,
-                            this->rows_,
-                            this->cols_,
-                            ctx);
+    phi::funcs::LayerNormBackward<T, U>(src,
+                                        dout,
+                                        gamma,
+                                        mean,
+                                        variance,
+                                        d_src,
+                                        d_scale,
+                                        d_bias,
+                                        epsilon_,
+                                        this->rows_,
+                                        this->cols_,
+                                        ctx);
   }

   // out = layernorm(residual + dropout(src + bias))
@@ -457,7 +457,7 @@ class FusedDropoutLayerNormHelper
     if (this->cols_ % vec_size != 0) {
       vec_size = 1;
     }
-    int threads = GetDesiredBlockDim(this->cols_ / vec_size);
+    int threads = phi::funcs::GetDesiredBlockDim(this->cols_ / vec_size);
     int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
     increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
     LaunchLayernormResidualDropoutBias<T,
@@ -537,18 +537,18 @@ class FusedDropoutLayerNormHelper
           d_residual,
           d_dropout_src);
     } else {
-      LayerNormBackward<T, U, is_same_type>(layernorm_src,
-                                            d_out,
-                                            gamma,
-                                            mean,
-                                            variance,
-                                            d_layernorm_src,
-                                            d_scale,
-                                            d_layernorm_bias,
-                                            epsilon_,
-                                            this->rows_,
-                                            this->cols_,
-                                            ctx);
+      phi::funcs::LayerNormBackward<T, U, is_same_type>(layernorm_src,
+                                                        d_out,
+                                                        gamma,
+                                                        mean,
+                                                        variance,
+                                                        d_layernorm_src,
+                                                        d_scale,
+                                                        d_layernorm_bias,
+                                                        epsilon_,
+                                                        this->rows_,
+                                                        this->cols_,
+                                                        ctx);
       this->ResidualDropoutBiasGrad(
           ctx, d_layernorm_src, mask, d_dropout_src, d_residual, d_bias);
     }
......
@@ -23,9 +23,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"
@@ -37,7 +37,7 @@ USE_OP_ITSELF(dropout);
 USE_OP_ITSELF(layer_norm);

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
......
@@ -15,12 +15,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/matmul_v2_op.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/process_group_nccl.h"
@@ -120,7 +120,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
         ctx, bsz_seq, d_model, dropout_param2, epsilon2);

-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const phi::DenseTensor* in = &x;

     const U* ln1_scale_ptr =
@@ -238,7 +238,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     DropoutParam dropout_param1(context, 1);
     DropoutParam dropout_param2(context, 2);

-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     dev_ctx.Alloc<uint8_t>(dropout1_mask,
                            dropout1_mask->numel() * sizeof(uint8_t));
@@ -369,7 +369,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
         ctx, bsz_seq, d_model, dropout_param2, epsilon2);

-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     const U* ln1_gamma_ptr =
         ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
     const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
@@ -485,7 +485,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
   }

   void Compute(const framework::ExecutionContext& context) const override {
-    using U = LayerNormParamType<T>;
+    using U = phi::funcs::LayerNormParamType<T>;
     auto& dev_ctx = context.template device_context<phi::GPUContext>();
     auto d_out =
         *context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
......
@@ -22,7 +22,7 @@ namespace operators {
 #define LN_NUM_COLS 1024

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
@@ -174,8 +174,8 @@ __global__ void FusedLayernormResidualDropoutBias(
         relu);
   }

-  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
-  var_val = BlockReduceSum<U>(var_val, shared_var);
+  mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
     auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
         static_cast<float>(1.) / static_cast<float>(cols));
@@ -189,7 +189,7 @@ __global__ void FusedLayernormResidualDropoutBias(
   __syncthreads();

   mean_val = mean_share;
-  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
+  U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));

   // calculate layernorm_dst
   CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(scale,
@@ -358,8 +358,8 @@ __global__ void FusedLayernormResidualDropoutBiasInfer(
         relu);
   }

-  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
-  var_val = BlockReduceSum<U>(var_val, shared_var);
+  mean_val = phi::funcs::BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = phi::funcs::BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
     auto scale = static_cast<LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
         static_cast<float>(1.) / static_cast<float>(cols));
@@ -372,7 +372,7 @@ __global__ void FusedLayernormResidualDropoutBiasInfer(
   __syncthreads();

   mean_val = mean_share;
-  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
+  U invvar = phi::funcs::rsqrt_<U>(var_share + static_cast<U>(epsilon));

   // calculate layernorm_dst
   CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(scale,
@@ -412,7 +412,7 @@ struct FusedLayernormResidualDropoutBiasFunctor {
       LayerNormParamType<T> *mean,
       LayerNormParamType<T> *var,
       cudaStream_t stream) {
-    int blockDim = GetDesiredBlockDim(cols / VecSize);
+    int blockDim = phi::funcs::GetDesiredBlockDim(cols / VecSize);
     if (mean != nullptr && var != nullptr) {
       LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
                                                         MaskType,
@@ -859,9 +859,9 @@ void LaunchLayernormResidualDropoutBias(
         mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
   }
   // call layernorm forward
-  switch (GetDesiredBlockDim(cols)) {
+  switch (phi::funcs::GetDesiredBlockDim(cols)) {
     FIXED_BLOCK_DIM_CASE(
-        LayerNormForward<T, U, kBlockDim, ScaleBiasWithSameTypeX>
+        phi::funcs::LayerNormForward<T, U, kBlockDim, ScaleBiasWithSameTypeX>
         <<<rows, kBlockDim, 0, ctx.stream()>>>(
             dst,
             scale,
@@ -1005,7 +1005,7 @@ void LaunchLayernormResidualDropoutBias(
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
-    int blockDim = GetDesiredBlockDim(cols);
+    int blockDim = phi::funcs::GetDesiredBlockDim(cols);
     LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
                                                       uint8_t,
                                                       1,
@@ -1043,7 +1043,7 @@ void LaunchLayernormResidualDropoutBias(
         break;
     }
   } else {
-    int blockDim = GetDesiredBlockDim(cols / VecSize);
+    int blockDim = phi::funcs::GetDesiredBlockDim(cols / VecSize);
     LaunchFusedLayernormResidualDropoutBiasCUDAKernel<T,
                                                       uint8_t,
                                                       VecSize,
@@ -1102,24 +1102,25 @@ void LaunchLayernormResidualDropoutGrad(
   if (!is_upscale_in_train) {
     factor = static_cast<T>(1.0f);
   }
-  ln_bwd_fast_kernel_driver<T,
-                            U,
-                            LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
-                            MaskType>(dev_ctx,
-                                      rows,
-                                      cols,
-                                      epsilon,
-                                      layernorm_src,
-                                      scale,
-                                      mean,
-                                      var,
-                                      d_out,
-                                      d_residual,
-                                      d_scale,
-                                      d_layernorm_bias,
-                                      mask_data,
-                                      factor,
-                                      d_dropout_src);
+  phi::funcs::ln_bwd_fast_kernel_driver<
+      T,
+      U,
+      LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>,
+      MaskType>(dev_ctx,
+                rows,
+                cols,
+                epsilon,
+                layernorm_src,
+                scale,
+                mean,
+                var,
+                d_out,
+                d_residual,
+                d_scale,
+                d_layernorm_bias,
+                mask_data,
+                factor,
+                d_dropout_src);
 }

 }  // namespace operators
......
@@ -235,7 +235,7 @@ struct TestFusedLayernormResidualDropoutBias {
     if (cols % 4 != 0) {
       VecSize = 1;
     }
-    int threads = paddle::operators::GetDesiredBlockDim(cols / VecSize);
+    int threads = phi::funcs::GetDesiredBlockDim(cols / VecSize);
     const int increment = ((cols - 1) / (threads * VecSize) + 1) * VecSize;

     T *bias_ptr = nullptr;
......
@@ -24,17 +24,17 @@ namespace cub = hipcub;
 #include <iostream>

-#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/backends/gpu/gpu_dnn.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"

-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {

 template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
+using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
@@ -331,6 +331,38 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 }
 #endif

+template <typename T>
+inline HOSTDEVICE T roundWithTiesToEven(T x) {
+  T xLower = floor(x);
+  T xUpper = ceil(x);
+  // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
+  // even.
+  T dLower = x - xLower;
+  T dUpper = xUpper - x;
+  return static_cast<T>(
+      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
+          ? xLower
+          : xUpper);
+}
+
+template <typename T>
+__forceinline__ __device__ int8_t quant_helper(const T input,
+                                               const float scale,
+                                               const int round_type,
+                                               const float max_bound,
+                                               const float min_bound) {
+  float quant_value = max_bound * scale * static_cast<float>(input);
+
+  if (round_type == 0) {
+    quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
+  } else {
+    quant_value = static_cast<float>(round(quant_value));
+  }
+  quant_value = quant_value > max_bound ? max_bound : quant_value;
+  quant_value = quant_value < min_bound ? min_bound : quant_value;
+  return static_cast<int8_t>(quant_value);
+}
+
 template <typename T, typename U, bool ScaleBiasWithSameTypeX>
 using LayerNormScaleBiasT =
     typename std::conditional<ScaleBiasWithSameTypeX, T, U>::type;
@@ -947,17 +979,17 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
     // get temp space for dscale and dbias.
     phi::DenseTensor dscale_temp;
     dscale_temp.Resize({gridx, cols});
-    dscale_temp.mutable_data<U>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<U>(&dscale_temp);
     U *dscale_temp_ptr = dscale_temp.data<U>();

     phi::DenseTensor dbias_temp;
     dbias_temp.Resize({gridx, cols});
-    dbias_temp.mutable_data<U>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<U>(&dbias_temp);
     U *dbias_temp_ptr = dbias_temp.data<U>();

     if (mask_ptr != nullptr) {
       if (d_dropout_src_ptr == nullptr) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
+        PADDLE_THROW(phi::errors::InvalidArgument(
             "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
             "can't be null"));
       }
@@ -1069,8 +1101,8 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
       // #blocks: 32,#threads_per_block: 512
       // Note: it is not supported for double type.
       if (sizeof(U) > 4) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Only support float and fp16 type"));
+        PADDLE_THROW(
+            phi::errors::InvalidArgument("Only support float and fp16 type"));
       } else {
         int gridx_2 = 0;
@@ -1103,7 +1135,7 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
 #undef LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
     }
   } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(phi::errors::InvalidArgument(
         "Fast layer_norm kernel is only used when feature_size is 1024"));
   }
 }
@@ -1891,11 +1923,11 @@ static void LayerNormBackward(
       constexpr int part_size = BDIMY2 * VPT;
       const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

-      auto part_grad_gamma_ptr = memory::Alloc(
+      auto part_grad_gamma_ptr = paddle::memory::Alloc(
           dev_ctx.GetPlace(),
           part_size * feature_size * sizeof(U),
           phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-      auto part_grad_beta_ptr = memory::Alloc(
+      auto part_grad_beta_ptr = paddle::memory::Alloc(
           dev_ctx.GetPlace(),
           part_size * feature_size * sizeof(U),
           phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
@@ -1959,5 +1991,5 @@ static void LayerNormBackward(
     }
   }
 }
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
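The "replace mutable_data" bullet in the commit message corresponds to the temporary-tensor allocations in the hunk above: inside phi, scratch space for dscale/dbias is obtained through the device context rather than through DenseTensor::mutable_data. A sketch of that pattern in isolation, using the same names as ln_bwd_fast_kernel_driver above:

// Old (fluid) style: the tensor allocates its own storage for type U.
//   dscale_temp.mutable_data<U>(dev_ctx.GetPlace());

// New (phi) style: the GPU context performs the typed allocation.
phi::DenseTensor dscale_temp;
dscale_temp.Resize({gridx, cols});           // gridx x cols partial results
dev_ctx.template Alloc<U>(&dscale_temp);     // allocate U elements on dev_ctx's place
U *dscale_temp_ptr = dscale_temp.data<U>();  // raw pointer consumed by the kernels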
@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/layer_norm_grad_kernel.h"

-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/layer_norm_util.h"

 namespace phi {
@@ -34,7 +34,7 @@ void LayerNormGradKernel(const Context &dev_ctx,
                          DenseTensor *x_grad,
                          DenseTensor *scale_grad,
                          DenseTensor *bias_grad) {
-  using U = paddle::operators::LayerNormParamType<T>;
+  using U = phi::funcs::LayerNormParamType<T>;
   // d_x, d_scale, d_bias may be nullptr
   auto *d_x = x_grad;
   auto *d_scale = scale_grad;
@@ -84,7 +84,7 @@ void LayerNormGradKernel(const Context &dev_ctx,
              : dev_ctx.template Alloc<ScaleBiasT>(d_bias));                \
     auto *d_x_data =                                                       \
         (d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x));       \
-    paddle::operators::LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
+    phi::funcs::LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(        \
         x_data,                                                            \
         d_y_data,                                                          \
         scale_data,                                                        \
......
@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/layer_norm_kernel.h"

-#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
 #include "paddle/phi/kernels/funcs/layer_norm_util.h"

 namespace phi {
@@ -36,9 +36,9 @@ void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
   auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
   int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
   int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
-  switch (paddle::operators::GetDesiredBlockDim(feature_size)) {
+  switch (phi::funcs::GetDesiredBlockDim(feature_size)) {
     FIXED_BLOCK_DIM_CASE(
-        paddle::operators::LayerNormForward<T, U, kBlockDim>
+        phi::funcs::LayerNormForward<T, U, kBlockDim>
         <<<batch_size, kBlockDim, 0, stream>>>(
             input, scale, bias, output, mean, variance, eps, feature_size));
     default:
@@ -65,7 +65,7 @@ void LayerNormKernel(const Context &dev_ctx,
                      DenseTensor *y,
                      DenseTensor *mean,
                      DenseTensor *var) {
-  using U = paddle::operators::LayerNormParamType<T>;
+  using U = phi::funcs::LayerNormParamType<T>;
   auto *scale = scale_opt.get_ptr();
   auto *bias = bias_opt.get_ptr();
@@ -109,9 +109,9 @@ void LayerNormKernel(const Context &dev_ctx,
 #define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX)  \
   do {                                                                      \
-    switch (paddle::operators::GetDesiredBlockDim(feature_size)) {          \
+    switch (phi::funcs::GetDesiredBlockDim(feature_size)) {                 \
       FIXED_BLOCK_DIM_CASE(                                                 \
-          paddle::operators::                                               \
+          phi::funcs::                                                      \
              LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX>   \
          <<<batch_size, kBlockDim, 0, stream>>>(                            \
              x_data,                                                        \
@@ -140,13 +140,13 @@ void LayerNormKernel(const Context &dev_ctx,
     const int ROWS_PER_CTA = WARPS_M;                                       \
     const int grid = static_cast<int>(                                      \
         std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));          \
-    paddle::operators::fast_ln_fwd_kernel<T,                                \
-                                          U,                                \
-                                          ScaleT,                           \
-                                          VecSize,                          \
-                                          WARPS_M,                          \
-                                          WARPS_N,                          \
-                                          BYTES_PER_LDG>                    \
+    phi::funcs::fast_ln_fwd_kernel<T,                                       \
+                                   U,                                       \
+                                   ScaleT,                                  \
+                                   VecSize,                                 \
+                                   WARPS_M,                                 \
+                                   WARPS_N,                                 \
+                                   BYTES_PER_LDG>                           \
        <<<grid, THREADS_PER_CTA, 0, stream>>>(                              \
            batch_size,                                                      \
            feature_size,                                                    \
......