diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index 46c5f7c0e5f94138cfd3a8cdf04758786d2ce517..708aef3d690f9796da2a766bfede5a18d5c768f3 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
 #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/layer_norm_kernel.h"
 
 DECLARE_bool(use_fast_math);
 
@@ -347,6 +348,18 @@ class FusedDropoutHelper {
   DropoutParam dropout_param_;
 };
 
+template <typename T>
+struct PDDataTypeTraits {
+  using DataType = T;
+};
+
+template <>
+struct PDDataTypeTraits<phi::dtype::float16> {
+  // Since LayerNormDirectCUDAFunctor register half type, we need to convert
+  // phi::float16 to half.
+  using DataType = half;
+};
+
 template <typename T>
 class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, uint8_t> {
  public:
@@ -367,13 +380,22 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, uint8_t> {
                  T* out,
                  LayerNormParamType<T>* mean,
                  LayerNormParamType<T>* variance) {
-    using U = LayerNormParamType<T>;
-    switch (GetDesiredBlockDim(this->cols_)) {
-      FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, U, kBlockDim>
-          <<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
-              src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
-    }
+    using InDataType = typename PDDataTypeTraits<T>::DataType;
+    using OutDataType = typename PDDataTypeTraits<T>::DataType;
+
+    phi::LayerNormDirectCUDAFunctor<InDataType, LayerNormParamType<T>>
+        layer_norm;
+    std::vector<int> src_shape{this->rows_, this->cols_};
+    layer_norm(ctx.stream(),
+               reinterpret_cast<const InDataType*>(src),
+               src_shape,
+               beta,
+               gamma,
+               reinterpret_cast<OutDataType*>(out),
+               mean,
+               variance,
+               1,
+               epsilon_);
   }
 
   void LayerNormGrad(const phi::GPUContext& ctx,
diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
index 1dd1070884732e54d1691a86c4431c5136499993..1350cb2209c3187978299c184d1da97c5060cec4 100644
--- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -50,6 +50,7 @@ void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
 }
 
 template class LayerNormDirectCUDAFunctor<float, float>;
+template class LayerNormDirectCUDAFunctor<half, float>;
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
 template class LayerNormDirectCUDAFunctor<double, double>;
 #endif