From b851515b16d179f35410836a17f855b9b6a9c268 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Thu, 28 Sep 2017 15:41:20 -0700 Subject: [PATCH] merge new op grammar --- paddle/operators/gather_op.cu | 9 ++------- paddle/operators/scatter_op.cu | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu index f3ed692666..f7533cdd64 100644 --- a/paddle/operators/gather_op.cu +++ b/paddle/operators/gather_op.cu @@ -20,13 +20,8 @@ namespace paddle { namespace operators { -// template -__global__ void print_arr(const float *params, const int N) { - CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); } -} - template -class GatherOpCUDAKernel : public framework::OpKernel { +class GatherOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel { }; template -class GatherGradOpCUDAKernel : public framework::OpKernel { +class GatherGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu index e27a926c6a..89d23945e0 100644 --- a/paddle/operators/scatter_op.cu +++ b/paddle/operators/scatter_op.cu @@ -20,7 +20,7 @@ namespace paddle { namespace operators { template -class ScatterOpCUDAKernel : public framework::OpKernel { +class ScatterOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -37,7 +37,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel { }; template -class ScatterGradOpCUDAKernel : public framework::OpKernel { +class ScatterGradOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), -- GitLab