未验证 提交 cdab3a44 编写于 作者: S ShenLiang 提交者: GitHub

Fix nullptr to TestFuseGemmEpilogueReluBWDFP* (#48997) (#49090)

Co-authored-by: Ming-Xu Huang <mingh@nvidia.com>
上级 ddcd1b61
...@@ -139,9 +139,8 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { ...@@ -139,9 +139,8 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
// Note (Ming Huang): Reserve space of relu is a bit-mask,
// which cannot pass nan_and_inf checking if shape is set. if (ctx->HasOutput("ReserveSpace")) {
if (activation == "gelu" && ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims)); ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
} }
} }
......
...@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -107,15 +107,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
sizeof(bias_data))); sizeof(bias_data)));
if (enable_auxiliary && activation != "none") { if (enable_auxiliary && activation != "none") {
size_t reserve_space_size = 0; // Note (Ming Huang): The initialization of ReserveSpace happens in the
// dev_ctx.Alloc. Therefore, we set the real data type up here.
if (activation == "relu") { if (activation == "relu") {
// Count in bits. paddle::experimental::DataType rs_type =
reserve_space_size = phi::product(out->dims()) / 8; paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else { } else {
reserve_space_size = phi::product(out->dims()) * sizeof(T); size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
} }
dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>()); void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
...@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -185,7 +191,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
stream, stream,
workspace->ptr(), workspace->ptr(),
workspace_size); workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle, platform::dynload::cublasLtMatmul(lt_handle,
operation_desc, operation_desc,
...@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dx))); sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") { if (activation_grad != "none") {
auto* aux_data = reserve_space->data<T>(); auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, dx_operation_desc,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册