diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h
index b400c104407d3c7eb602b1b7535e4aa9bc199585..2ae11376a2ba9d2a140ad468e263cf06f25417ca 100644
--- a/paddle/operators/gather.cu.h
+++ b/paddle/operators/gather.cu.h
@@ -46,9 +46,9 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
-               Tensor* output) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
+void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
+               const Tensor* index, Tensor* output) {
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
@@ -68,8 +68,11 @@ void GPUGather(const Place& place, const Tensor* src, const Tensor* index,
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
-  GatherCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
-                                       slice_size);
+
+  GatherCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
 }
 
 }  // namespace operators
diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu
index 06004614b2c6d6b2ec61dd45e5df0bab52dba922..9937be5915d38ea55507fbabce4b5ce6b2f9618f 100644
--- a/paddle/operators/gather_op.cu
+++ b/paddle/operators/gather_op.cu
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
 
     output->mutable_data<T>(ctx.GetPlace());
 
-    GPUGather<T>(ctx.GetPlace(), x, index, output);
+    GPUGather<T>(ctx.device_context(), x, index, output);
   }
 };
 
@@ -42,7 +42,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    LOG(INFO) << "Gather grad here";
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -53,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::GPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
 
-    GPUScatterAssign<T>(ctx.GetPlace(), dO, Index, dX);
+    GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
   }
 };
 
diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h
index add4791a793a84f8271b051049e7e16c4d66fcc5..f4a3965d94c11fe53e72d4c6df5a6f88ccc420fb 100644
--- a/paddle/operators/scatter.cu.h
+++ b/paddle/operators/scatter.cu.h
@@ -45,11 +45,11 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * return: output tensor
  */
 template <typename T>
-void GPUScatterAssign(const platform::Place& place,
+void GPUScatterAssign(const platform::DeviceContext& ctx,
                       const paddle::framework::Tensor* src,
                       const paddle::framework::Tensor* index,
                       paddle::framework::Tensor* output) {
-  PADDLE_ENFORCE(platform::is_gpu_place(place));
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
@@ -70,8 +70,10 @@ void GPUScatterAssign(const platform::Place& place,
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  ScatterCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
-                                        slice_size);
+  ScatterCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
 }
 
 }  // namespace operators
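The header changes above thread the device context into GPUGather/GPUScatterAssign so the CUDA kernels are launched on the context's stream (the fourth launch parameter) instead of the default stream. Below is a minimal, standalone sketch of that launch pattern; the kernel, sizes, and data are illustrative only and not Paddle's code, and the stream is created locally where the patch would take it from platform::CUDADeviceContext::stream().

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative gather kernel using the same indexing scheme as GatherCUDAKernel:
// output[i] = params[indices[i / slice_size] * slice_size + i % slice_size]
__global__ void GatherSketch(const float* params, const int* indices,
                             float* output, int index_size, int slice_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < index_size * slice_size) {
    int indices_i = i / slice_size;
    int slice_i = i - indices_i * slice_size;  // offset inside one slice
    int gather_i = indices[indices_i];         // which row to gather
    output[i] = params[gather_i * slice_size + slice_i];
  }
}

int main() {
  const int index_size = 2, slice_size = 4, n = index_size * slice_size;
  float h_params[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};  // 3 rows x 4 cols
  int h_indices[index_size] = {2, 0};
  float h_out[n];

  float *d_params, *d_out;
  int *d_indices;
  cudaMalloc(&d_params, sizeof(h_params));
  cudaMalloc(&d_indices, sizeof(h_indices));
  cudaMalloc(&d_out, sizeof(h_out));

  // Key point of the patch: launch on an explicitly supplied stream rather
  // than the default stream. Here the stream is created locally; in the
  // patch it comes from the CUDADeviceContext.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemcpyAsync(d_params, h_params, sizeof(h_params),
                  cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(d_indices, h_indices, sizeof(h_indices),
                  cudaMemcpyHostToDevice, stream);

  int block = 512;
  int grid = (n + block - 1) / block;  // same grid computation as the patch
  GatherSketch<<<grid, block, 0, stream>>>(d_params, d_indices, d_out,
                                           index_size, slice_size);

  cudaMemcpyAsync(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // expect: 8 9 10 11 0 1 2 3
  printf("\n");

  cudaFree(d_params);
  cudaFree(d_indices);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return 0;
}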
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
index 831eabdae4f6567c74bf690fd0655de049dea2b6..6d13a876f98a400f8b0b8cdc0f3ecf297895e335 100644
--- a/paddle/operators/scatter_op.cu
+++ b/paddle/operators/scatter_op.cu
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
 
     Out->ShareDataWith(*Ref);
 
-    GPUScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
   }
 };
 
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel {
     dRef->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+    GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
   }
 };
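At the call sites the op kernels now pass ctx.device_context() instead of ctx.GetPlace(), and inside the headers the abstract platform::DeviceContext is cast down (with reinterpret_cast, as the patch does) to reach the CUDA stream. The sketch below mirrors that shape with simplified stand-in types; DeviceContext and CUDADeviceContext here are mocks for illustration, not Paddle's real classes.

#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-ins for the framework types; the real
// platform::DeviceContext / platform::CUDADeviceContext carry much more state.
struct DeviceContext {
  virtual ~DeviceContext() {}
};

struct CUDADeviceContext : public DeviceContext {
  explicit CUDADeviceContext(cudaStream_t s) : stream_(s) {}
  cudaStream_t stream() const { return stream_; }
  cudaStream_t stream_;
};

__global__ void Noop() {}

// Mirrors the patched GPUGather/GPUScatterAssign: take the abstract context,
// cast to the CUDA context, and launch on its stream.
void LaunchOnContextStream(const DeviceContext& ctx) {
  cudaStream_t stream =
      reinterpret_cast<const CUDADeviceContext&>(ctx).stream();
  Noop<<<1, 1, 0, stream>>>();
}

int main() {
  cudaStream_t s;
  cudaStreamCreate(&s);
  CUDADeviceContext ctx(s);
  LaunchOnContextStream(ctx);  // kernel is queued on ctx's stream
  cudaStreamSynchronize(s);
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaStreamDestroy(s);
  return 0;
}

Passing the DeviceContext instead of the Place gives the helpers direct access to the per-device stream; the trade-off is the unchecked cast inside the headers, which is why they no longer assert is_gpu_place on their own.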