未验证 提交 9a849a37 编写于 作者: W Wang Bojun 提交者: GitHub

fix fp16 (#46713)

* fix fp16

* remove debug info

* code style refine
上级 888223b7
......@@ -80,6 +80,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
int m,
int n,
float epsilon) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
__shared__ float s_mean;
__shared__ float s_variance;
float x_sum = 0.0f;
......@@ -138,6 +139,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
}
normed_output[index] = val;
}
#endif
}
#endif
......@@ -418,6 +420,7 @@ int PrelnResidualBiasPluginDynamic::enqueue(
float *var = nullptr;
const int VecSize = 8;
// if odd
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
if (hidden & 1 == 0) {
int half_n = hidden / 2;
int half_n_32 = (half_n + 31) / 32 * 32;
......@@ -459,6 +462,32 @@ int PrelnResidualBiasPluginDynamic::enqueue(
var,
stream);
}
#else
paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
uint8_t,
VecSize,
float,
false>()(
rows,
cols,
seed,
dropout_prob,
is_upscale_in_train,
is_test,
increment,
epsilon,
src,
residual,
bias,
scale,
layernorm_bias,
mask_data,
dst,
layernorm_dst,
mean,
var,
stream);
#endif
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) tensorRT plugin should be "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册