Unverified · Commit 0b24d167 authored by Yuanle Liu, committed by GitHub

fix fc and fused_fc_elementwise_layernorm kernel diff (#49778)

Parent 5d60ff91
@@ -276,9 +276,9 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
       half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i);
       half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0;
 #else
-      half tmp_0 = static_cast<half>(static_cast<float>(save_ptr[save_index]) -
-                                     static_cast<float>(mean_i) /
-                                     static_cast<float>(std_i));
+      half tmp_0 = static_cast<half>((static_cast<float>(save_ptr[save_index]) -
+                                      static_cast<float>(mean_i)) /
+                                     static_cast<float>(std_i));
       half tmp_1 = scale ? static_cast<half>(static_cast<float>(scale[j]) *
                                              static_cast<float>(tmp_0))
                          : tmp_0;
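For context (not part of the commit): the non-`__hdiv` branch above had an operator-precedence bug. Division binds tighter than subtraction, so the old code computed `x - (mean / std)` instead of the intended `(x - mean) / std` that the `__hsub`/`__hdiv` branch computes; the added parentheses bring the two paths into agreement. A minimal standalone C++ sketch of the difference (the values are hypothetical stand-ins, not taken from the kernel):

```cpp
#include <cstdio>

int main() {
  // Stand-ins for save_ptr[save_index], mean_i, and std_i.
  float x = 4.0f, mean = 2.0f, std_dev = 2.0f;
  float before_fix = x - mean / std_dev;   // parsed as x - (mean / std) = 3
  float after_fix = (x - mean) / std_dev;  // intended normalization     = 1
  std::printf("before: %g, after: %g\n", before_fix, after_fix);
  return 0;
}
```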
@@ -394,19 +394,16 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
     auto* out_data = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
 
     auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
-    blas.GEMM(false,
-              false,
+    blas.GEMM(CblasNoTrans,
+              CblasNoTrans,
               M,
               N,
               K,
               static_cast<T>(1.0),
               x_data,
-              K,
               w_data,
-              N,
               static_cast<T>(0.0),
-              out_data,
-              N);
+              out_data);
     auto* y = ctx.Input<phi::DenseTensor>("Y");
     auto* bias_0 = ctx.Input<phi::DenseTensor>("Bias0");
     auto* bias_1 = ctx.Input<phi::DenseTensor>("Bias1");
...
@@ -345,19 +345,16 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
         errors::PermissionDenied(
             "Weight padding in fc can not be used in GPU scope."));
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-    blas.GEMM(false,
-              false,
+    blas.GEMM(CblasNoTrans,
+              CblasNoTrans,
               M,
               N,
               K,
               static_cast<T>(1.0),
               X,
-              K,
               W,
-              N,
               static_cast<T>(0.0),
-              Y,
-              N);
+              Y);
     if (B == NULL) {
       return;
     }
...
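For context (not part of the commit): both GEMM hunks make the same change, switching from the bool-transpose overload that takes explicit leading dimensions to the `CBLAS_TRANSPOSE` overload that derives them, presumably so the fused op and the plain fc functor dispatch through the same GEMM path and no longer produce differing outputs. The sketch below uses hypothetical `gemm` overloads that only mirror the shape of such a wrapper (they are not Paddle's actual implementation) to show why dropping the `K`/`N`/`N` arguments does not change the math for a row-major, no-transpose call: the derived leading dimensions are lda = K, ldb = N, ldc = N.

```cpp
#include <cstdio>

enum CBLAS_TRANSPOSE { CblasNoTrans, CblasTrans };

// Explicit-leading-dimension overload (the old call style in this diff).
// The transpose flags are ignored here; the sketch only exercises the
// no-transpose case.
void gemm(bool /*transA*/, bool /*transB*/, int M, int N, int K, float alpha,
          const float* A, int lda, const float* B, int ldb, float beta,
          float* C, int ldc) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += A[i * lda + k] * B[k * ldb + j];
      C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
    }
}

// CBLAS_TRANSPOSE overload (the new call style): the leading dimensions are
// derived from M/N/K instead of being passed by the caller.
void gemm(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
          float alpha, const float* A, const float* B, float beta, float* C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  gemm(transA == CblasTrans, transB == CblasTrans, M, N, K, alpha, A, lda, B,
       ldb, beta, C, /*ldc=*/N);
}

int main() {
  const int M = 2, N = 2, K = 2;
  float A[M * K] = {1, 2, 3, 4}, B[K * N] = {5, 6, 7, 8}, C[M * N] = {};
  gemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, A, B, 0.0f, C);
  std::printf("%g %g / %g %g\n", C[0], C[1], C[2], C[3]);  // 19 22 / 43 50
  return 0;
}
```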