未验证 提交 c1cadcca 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

fix scale type in alpha and beta (#48887)

上级 c7d6d9f4
......@@ -1530,21 +1530,16 @@ class CublasFusedMLP {
beta16 =
add_residual ? static_cast<half>(1.0) : static_cast<half>(0.0);
void *alpha = nullptr, *beta = nullptr;
void *alpha = &alpha32, *beta = &beta32;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, float>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, phi::dtype::float16>::value) {
}
if (std::is_same<T, phi::dtype::float16>::value &&
FLAGS_gemm_use_half_precision_compute_type) {
alpha = &alpha16;
beta = &beta16;
} else {
PADDLE_ENFORCE_EQ(true,
false,
platform::errors::InvalidArgument(
"Only support double, float, half data type. "));
}
const auto *x_data = x->data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册