diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
index 0500f76110f33576c8a4d12be4c5d6c8376f0329..db83d6cda45e39c8ea8a972bcbd019c0b6392be8 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -1530,21 +1530,16 @@ class CublasFusedMLP {
         beta16 =
             add_residual ? static_cast<half>(1.0) : static_cast<half>(0.0);
 
-    void *alpha = nullptr, *beta = nullptr;
+    void *alpha = &alpha32, *beta = &beta32;
     if (std::is_same<T, double>::value) {
       alpha = &alpha64;
       beta = &beta64;
-    } else if (std::is_same<T, float>::value) {
-      alpha = &alpha64;
-      beta = &beta64;
-    } else if (std::is_same<T, platform::float16>::value) {
+    }
+
+    if (std::is_same<T, platform::float16>::value &&
+        FLAGS_gemm_use_half_precision_compute_type) {
       alpha = &alpha16;
       beta = &beta16;
-    } else {
-      PADDLE_ENFORCE_EQ(true,
-                        false,
-                        platform::errors::InvalidArgument(
-                            "Only support double, float, half data type. "));
     }
 
     const auto *x_data = x->data<T>();