diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index aaea0b66ff5e495c87931a9f0272edb1af2c2393..22340210b5715d6df0824359c6cc13c93ecdd31f 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { << " , activation = " << activation; bool enable_auxiliary = reserve_space == nullptr ? false : true; - dev_ctx->Alloc(out, out->numel() * sizeof(T)); + dev_ctx.Alloc(out, out->numel() * sizeof(T)); auto* out_data = out->data(); auto x_mat_dims = @@ -110,7 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { } else { reserve_space_size = phi::product(out->dims()) * sizeof(T); } - dev_ctx->Alloc(reserve_space, out->type(), reserve_space_size); + dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size); void* aux_data = reinterpret_cast(reserve_space->data()); PADDLE_ENFORCE_GPU_SUCCESS( @@ -492,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { workspace_size, phi::Stream(reinterpret_cast(dev_ctx.stream()))); - auto* dx_data = dev_ctx->Alloc(dx, dx->numel() * sizeof(T)); + auto* dx_data = dev_ctx.Alloc(dx, dx->numel() * sizeof(T)); const auto* y_data = y->data(); const auto* dout_data = dout->data(); const auto* a_data = kXGradAIsDZ ? dout_data : y_data; @@ -600,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { sizeof(epiloque_func_for_dy))); if (dbias) { - auto* dbias_data = dev_ctx->Alloc(dbias, dbias->numel() * sizeof(T)); + auto* dbias_data = dev_ctx.Alloc(dbias, dbias->numel() * sizeof(T)); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc, @@ -613,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { dev_ctx.GetPlace(), workspace_size, phi::Stream(reinterpret_cast(dev_ctx.stream()))); - auto* dy_data = dev_ctx->Alloc(dy, dy->numel() * sizeof(T)); + auto* dy_data = dev_ctx.Alloc(dy, dy->numel() * sizeof(T)); const auto* dout_data = dout->data(); const auto* x_data = x->data(); const auto* a_data = kYGradAIsDZ ? dout_data : x_data;