未验证 提交 569d6c5b 编写于 作者: S sneaxiy 提交者: GitHub

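Fix fused_gemm_epilogue_op compile error: call Alloc via `dev_ctx.Alloc(...)` instead of `dev_ctx->Alloc(...)`, since dev_ctx is not a pointer (#45862)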

上级 fc66fdb7
...@@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -46,7 +46,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
<< " , activation = " << activation; << " , activation = " << activation;
bool enable_auxiliary = reserve_space == nullptr ? false : true; bool enable_auxiliary = reserve_space == nullptr ? false : true;
dev_ctx->Alloc<T>(out, out->numel() * sizeof(T)); dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
auto* out_data = out->data<T>(); auto* out_data = out->data<T>();
auto x_mat_dims = auto x_mat_dims =
...@@ -110,7 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -110,7 +110,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
} else { } else {
reserve_space_size = phi::product(out->dims()) * sizeof(T); reserve_space_size = phi::product(out->dims()) * sizeof(T);
} }
dev_ctx->Alloc(reserve_space, out->type(), reserve_space_size); dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>()); void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
...@@ -492,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -492,7 +492,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dev_ctx->Alloc<T>(dx, dx->numel() * sizeof(T)); auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
const auto* y_data = y->data<T>(); const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>(); const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data; const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
...@@ -600,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -600,7 +600,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dy))); sizeof(epiloque_func_for_dy)));
if (dbias) { if (dbias) {
auto* dbias_data = dev_ctx->Alloc<T>(dbias, dbias->numel() * sizeof(T)); auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute( platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc, dy_operation_desc,
...@@ -613,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -613,7 +613,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx->Alloc<T>(dy, dy->numel() * sizeof(T)); auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
const auto* dout_data = dout->data<T>(); const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>(); const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data; const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册