未验证 提交 569d6c5b 编写于 作者: S sneaxiy 提交者: GitHub

fix fused_gemm_epilogue_op compile error (#45862)

上级 fc66fdb7
......@@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
<< " , activation = " << activation;
bool enable_auxiliary = reserve_space == nullptr ? false : true;
dev_ctx->Alloc<T>(out, out->numel() * sizeof(T));
dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
auto* out_data = out->data<T>();
auto x_mat_dims =
......@@ -110,7 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
} else {
reserve_space_size = phi::product(out->dims()) * sizeof(T);
}
dev_ctx->Alloc(reserve_space, out->type(), reserve_space_size);
dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
PADDLE_ENFORCE_GPU_SUCCESS(
......@@ -492,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dev_ctx->Alloc<T>(dx, dx->numel() * sizeof(T));
auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
......@@ -600,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dy)));
if (dbias) {
auto* dbias_data = dev_ctx->Alloc<T>(dbias, dbias->numel() * sizeof(T));
auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
......@@ -613,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx->Alloc<T>(dy, dy->numel() * sizeof(T));
auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册