未验证 提交 01c26ab2 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix fc kernel diff (#49781)

* fix fc kernel diff

* disable fc_elementwise_layernorm_fuse_pass
上级 8a934047
......@@ -171,8 +171,9 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"gpu_cpu_map_matmul_to_mul_pass",
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
// "fc_elementwise_layernorm_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass",
"runtime_context_cache_pass",
};
......
......@@ -276,8 +276,8 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i);
half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0;
#else
half tmp_0 = static_cast<half>(static_cast<float>(save_ptr[save_index]) -
static_cast<float>(mean_i) /
half tmp_0 = static_cast<half>((static_cast<float>(save_ptr[save_index]) -
static_cast<float>(mean_i)) /
static_cast<float>(std_i));
half tmp_1 = scale ? static_cast<half>(static_cast<float>(scale[j]) *
static_cast<float>(tmp_0))
......@@ -394,19 +394,16 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
auto* out_data = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.GEMM(false,
false,
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1.0),
x_data,
K,
w_data,
N,
static_cast<T>(0.0),
out_data,
N);
out_data);
auto* y = ctx.Input<framework::Tensor>("Y");
auto* bias_0 = ctx.Input<framework::Tensor>("Bias0");
auto* bias_1 = ctx.Input<framework::Tensor>("Bias1");
......
......@@ -292,19 +292,16 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
errors::PermissionDenied(
"Weight padding in fc can not be used in GPU scope."));
auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
blas.GEMM(false,
false,
blas.GEMM(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
static_cast<T>(1.0),
X,
K,
W,
N,
static_cast<T>(0.0),
Y,
N);
Y);
if (B == NULL) {
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册