From 9a849a371f96bef24d63061bec04b2b7536b777c Mon Sep 17 00:00:00 2001 From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com> Date: Sun, 9 Oct 2022 13:54:04 +0800 Subject: [PATCH] fix fp16 (#46713) * fix fp16 * remove debug info * code style refine --- .../plugin/preln_residual_bias_plugin.cu | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu index a56a85d6080..2735a22a14f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu @@ -80,6 +80,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2( int m, int n, float epsilon) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) __shared__ float s_mean; __shared__ float s_variance; float x_sum = 0.0f; @@ -138,6 +139,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2( } normed_output[index] = val; } +#endif } #endif @@ -418,6 +420,7 @@ int PrelnResidualBiasPluginDynamic::enqueue( float *var = nullptr; const int VecSize = 8; // if odd +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) if (hidden & 1 == 0) { int half_n = hidden / 2; int half_n_32 = (half_n + 31) / 32 * 32; @@ -459,6 +462,32 @@ int PrelnResidualBiasPluginDynamic::enqueue( var, stream); } +#else + paddle::operators::FusedLayernormResidualDropoutBiasFunctor()( + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + src, + residual, + bias, + scale, + layernorm_bias, + mask_data, + dst, + layernorm_dst, + mean, + var, + stream); +#endif #else PADDLE_THROW(platform::errors::Fatal( "The Ernie(Bert) tensorRT plugin should be " -- GitLab