Unverified commit e3e92c9a, authored by Leo Chen, committed by GitHub

update alloc usage (#45654)

Parent def71d38
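Every hunk below applies the same change: the two-argument `memory::Alloc(dev_ctx, size)` call is replaced by the three-argument form that takes the place and the stream explicitly. A minimal sketch of the pattern, assuming a `phi::GPUContext` named `dev_ctx` and the Paddle headers that declare `memory::Alloc`, `phi::Stream`, and `phi::StreamId`:

    // Before (old two-argument form): place and stream are taken from the
    // device context implicitly.
    memory::allocation::AllocationPtr workspace =
        memory::Alloc(dev_ctx, workspace_size);

    // After (this commit): place and size are passed explicitly, and the CUDA
    // stream is wrapped as a phi::Stream via its phi::StreamId handle.
    memory::allocation::AllocationPtr workspace = memory::Alloc(
        dev_ctx.GetPlace(),
        workspace_size,
        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

The cuSPARSE helpers near the end of the diff use the namespace-qualified `paddle::memory::Alloc` and a context named `ctx`, but the shape of the call is identical.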
@@ -150,8 +150,10 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
     // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
     size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     cudaStream_t stream = dev_ctx.stream();
-    memory::allocation::AllocationPtr workspace =
-        memory::Alloc(dev_ctx, workspace_size);
+    memory::allocation::AllocationPtr workspace = memory::Alloc(
+        dev_ctx.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
 
     double alpha64 = 1.0, beta64 = 0.0;
     float alpha32 = 1.0f, beta32 = 0.0f;
@@ -486,7 +488,10 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
               sizeof(aux_ld)));
     }
 
-    auto dx_workspace = memory::Alloc(dev_ctx, workspace_size);
+    auto dx_workspace = memory::Alloc(
+        dev_ctx.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
 
     auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
     const auto* y_data = y->data<T>();
@@ -605,7 +610,10 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
               sizeof(dbias_data)));
     }
-    auto dy_workspace = memory::Alloc(dev_ctx, workspace_size);
+    auto dy_workspace = memory::Alloc(
+        dev_ctx.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
 
     auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
     const auto* dout_data = dout->data<T>();
     const auto* x_data = x->data<T>();
...
@@ -207,7 +207,10 @@ struct MatrixEighFunctor<phi::GPUContext, T> {
     auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
     auto values_stride = dims[dim_size - 1];
     int lwork = 0;
-    auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size);
+    auto info = memory::Alloc(
+        dev_ctx.GetPlace(),
+        sizeof(int) * batch_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
     auto *info_ptr = reinterpret_cast<int *>(info->ptr());
 
     // When the input type is float32, and the feature value input dimension is
@@ -240,7 +243,10 @@ struct MatrixEighFunctor<phi::GPUContext, T> {
           out_value,
           &lwork);
     }
-    auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork);
+    auto work = memory::Alloc(
+        dev_ctx.GetPlace(),
+        sizeof(T) * lwork,
+        phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
     auto *work_ptr = reinterpret_cast<T *>(work->ptr());
     for (auto i = 0; i < batch_size; i++) {
       auto *input_data = input_vector + i * vector_stride;
...
@@ -522,7 +522,10 @@ void DotSdd(const phi::GPUContext& ctx,
                                            gpu_type,
                                            CUSPARSE_SDDMM_ALG_DEFAULT,
                                            &buffer_size);
-  auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size);
+  auto d_buffer_ptr = paddle::memory::Alloc(
+      ctx.GetPlace(),
+      buffer_size,
+      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
   void* d_buffer = static_cast<void*>(d_buffer_ptr->ptr());
 
   platform::dynload::cusparseSDDMM(handle,
@@ -616,7 +619,10 @@ void DotDsd(const phi::GPUContext& ctx,
                                           gpu_type,
                                           CUSPARSE_SPMM_ALG_DEFAULT,
                                           &buffer_size);
-  auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size);
+  auto d_buffer_ptr = paddle::memory::Alloc(
+      ctx.GetPlace(),
+      buffer_size,
+      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
   void* d_buffer = static_cast<void*>(d_buffer_ptr->ptr());
 
   platform::dynload::cusparseSpMM(handle,
...
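Note that the same three-argument call is repeated verbatim at every allocation site touched by this commit. If the repetition ever became a burden, it could be folded into a small wrapper; the sketch below is only an illustration, and `AllocForStream` is a hypothetical name, not an API introduced by this change.

    // Hypothetical helper (illustration only): allocate num_bytes on the
    // context's place, associated with the context's CUDA stream.
    static memory::allocation::AllocationPtr AllocForStream(
        const phi::GPUContext& dev_ctx, size_t num_bytes) {
      return memory::Alloc(
          dev_ctx.GetPlace(),
          num_bytes,
          phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
    }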