未验证 提交 e3e92c9a 编写于 作者: L Leo Chen 提交者: GitHub

update alloc usage (#45654)

上级 def71d38
......@@ -150,8 +150,10 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
- memory::allocation::AllocationPtr workspace =
- memory::Alloc(dev_ctx, workspace_size);
+ memory::allocation::AllocationPtr workspace = memory::Alloc(
+ dev_ctx.GetPlace(),
+ workspace_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
double alpha64 = 1.0, beta64 = 0.0;
float alpha32 = 1.0f, beta32 = 0.0f;
......@@ -486,7 +488,10 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(aux_ld)));
}
- auto dx_workspace = memory::Alloc(dev_ctx, workspace_size);
+ auto dx_workspace = memory::Alloc(
+ dev_ctx.GetPlace(),
+ workspace_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const auto* y_data = y->data<T>();
......@@ -605,7 +610,10 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(dbias_data)));
}
- auto dy_workspace = memory::Alloc(dev_ctx, workspace_size);
+ auto dy_workspace = memory::Alloc(
+ dev_ctx.GetPlace(),
+ workspace_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
......
......@@ -207,7 +207,10 @@ struct MatrixEighFunctor<phi::GPUContext, T> {
auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
auto values_stride = dims[dim_size - 1];
int lwork = 0;
- auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size);
+ auto info = memory::Alloc(
+ dev_ctx.GetPlace(),
+ sizeof(int) * batch_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto *info_ptr = reinterpret_cast<int *>(info->ptr());
// When the input type is float32, and the feature value input dimension is
......@@ -240,7 +243,10 @@ struct MatrixEighFunctor<phi::GPUContext, T> {
out_value,
&lwork);
}
- auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork);
+ auto work = memory::Alloc(
+ dev_ctx.GetPlace(),
+ sizeof(T) * lwork,
+ phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto *work_ptr = reinterpret_cast<T *>(work->ptr());
for (auto i = 0; i < batch_size; i++) {
auto *input_data = input_vector + i * vector_stride;
......
......@@ -522,7 +522,10 @@ void DotSdd(const phi::GPUContext& ctx,
gpu_type,
CUSPARSE_SDDMM_ALG_DEFAULT,
&buffer_size);
- auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size);
+ auto d_buffer_ptr = paddle::memory::Alloc(
+ ctx.GetPlace(),
+ buffer_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
void* d_buffer = static_cast<void*>(d_buffer_ptr->ptr());
platform::dynload::cusparseSDDMM(handle,
......@@ -616,7 +619,10 @@ void DotDsd(const phi::GPUContext& ctx,
gpu_type,
CUSPARSE_SPMM_ALG_DEFAULT,
&buffer_size);
- auto d_buffer_ptr = paddle::memory::Alloc(ctx, buffer_size);
+ auto d_buffer_ptr = paddle::memory::Alloc(
+ ctx.GetPlace(),
+ buffer_size,
+ phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
void* d_buffer = static_cast<void*>(d_buffer_ptr->ptr());
platform::dynload::cusparseSpMM(handle,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册