From 84b8baf1967e327712269e7632235438d09759d9 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Mon, 2 Oct 2017 15:50:24 -0700
Subject: [PATCH] gather scatter with cuda streams

---
 paddle/operators/gather.cu.h   | 13 ++++++++-----
 paddle/operators/gather_op.cu  |  5 ++---
 paddle/operators/scatter.cu.h  | 10 ++++++----
 paddle/operators/scatter_op.cu |  4 ++--
 4 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h
index b400c104407..2ae11376a2b 100644
--- a/paddle/operators/gather.cu.h
+++ b/paddle/operators/gather.cu.h
@@ -46,9 +46,9 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
-               Tensor* output) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
+void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
+               const Tensor* index, Tensor* output) {
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
@@ -68,8 +68,11 @@ void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
-  GatherCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
-                                       slice_size);
+
+  GatherCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
 }
 
 }  // namespace operators
diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu
index 06004614b2c..9937be5915d 100644
--- a/paddle/operators/gather_op.cu
+++ b/paddle/operators/gather_op.cu
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
 
     output->mutable_data<T>(ctx.GetPlace());
 
-    GPUGather<T>(ctx.GetPlace(), x, index, output);
+    GPUGather<T>(ctx.device_context(), x, index, output);
   }
 };
 
@@ -42,7 +42,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    LOG(INFO) << "Gather grad here";
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -53,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::GPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
 
-    GPUScatterAssign<T>(ctx.GetPlace(), dO, Index, dX);
+    GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
   }
 };
 
diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h
index add4791a793..f4a3965d94c 100644
--- a/paddle/operators/scatter.cu.h
+++ b/paddle/operators/scatter.cu.h
@@ -45,11 +45,11 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * return: output tensor
  */
 template <typename T>
-void GPUScatterAssign(const platform::Place& place,
+void GPUScatterAssign(const platform::DeviceContext& ctx,
                       const paddle::framework::Tensor* src,
                       const paddle::framework::Tensor* index,
                       paddle::framework::Tensor* output) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
@@ -70,8 +70,10 @@ void GPUScatterAssign(const platform::Place& place,
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  ScatterCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
-                                        slice_size);
+  ScatterCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
 }
 
 }  // namespace operators
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
index 831eabdae4f..6d13a876f98 100644
--- a/paddle/operators/scatter_op.cu
+++ b/paddle/operators/scatter_op.cu
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
 
     Out->ShareDataWith(*Ref);
 
-    GPUScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
   }
 };
 
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel {
     dRef->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+    GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
  }
 };
 
-- 
GitLab
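Note (not part of the patch): the change above routes every gather/scatter kernel launch onto the stream owned by the operator's DeviceContext instead of the default stream. The standalone CUDA sketch below illustrates that launch pattern with a plain cudaStream_t created locally; the kernel name GatherKernel, main(), and the toy sizes are illustrative only and are not Paddle APIs.

// Minimal sketch of launching a gather-style kernel on an explicit stream,
// mirroring the <<<grid, block, 0, stream>>> form used in the patch.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void GatherKernel(const T* src, const int* indices, T* out,
                             int index_size, int slice_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < index_size * slice_size) {
    int idx = i / slice_size;  // which index entry
    int off = i % slice_size;  // offset inside the slice
    out[i] = src[indices[idx] * slice_size + off];
  }
}

int main() {
  const int index_size = 2, slice_size = 4, src_rows = 3;
  float h_src[src_rows * slice_size];
  for (int i = 0; i < src_rows * slice_size; ++i) h_src[i] = float(i);
  int h_idx[index_size] = {2, 0};
  float h_out[index_size * slice_size];

  float *d_src, *d_out;
  int* d_idx;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_idx, sizeof(h_idx));
  cudaMalloc(&d_out, sizeof(h_out));

  cudaStream_t stream;
  cudaStreamCreate(&stream);  // stands in for the stream a DeviceContext owns

  cudaMemcpyAsync(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice, stream);

  int n = index_size * slice_size;
  int block = 512;
  int grid = (n + block - 1) / block;
  // Same launch form as the patch: grid, block, shared memory, stream.
  GatherKernel<float><<<grid, block, 0, stream>>>(d_src, d_idx, d_out,
                                                  index_size, slice_size);

  cudaMemcpyAsync(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  for (int i = 0; i < n; ++i) printf("%.0f ", h_out[i]);
  printf("\n");  // expected: 8 9 10 11 0 1 2 3

  cudaFree(d_src);
  cudaFree(d_idx);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return 0;
}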