From c1cadcca3513ac0be462f740046074ea7593f1db Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Fri, 9 Dec 2022 15:44:03 +0800 Subject: [PATCH] fix scale type in alpha and beta (#48887) --- .../fused/fused_multi_transformer_op.cu.h | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 0500f76110..db83d6cda4 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -1530,21 +1530,16 @@ class CublasFusedMLP { beta16 = add_residual ? static_cast(1.0) : static_cast(0.0); - void *alpha = nullptr, *beta = nullptr; + void *alpha = &alpha32, *beta = &beta32; if (std::is_same::value) { alpha = &alpha64; beta = &beta64; - } else if (std::is_same::value) { - alpha = &alpha64; - beta = &beta64; - } else if (std::is_same::value) { + } + + if (std::is_same::value && + FLAGS_gemm_use_half_precision_compute_type) { alpha = &alpha16; beta = &beta16; - } else { - PADDLE_ENFORCE_EQ(true, - false, - platform::errors::InvalidArgument( - "Only support double, float, half data type. ")); } const auto *x_data = x->data(); -- GitLab