From 2d876b864395513de8db52db944ee5e8150d2730 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Tue, 3 Oct 2017 10:54:22 -0700
Subject: [PATCH] gather scatter fix according to google style

---
 paddle/operators/cond_op.cc      |  4 ++--
 paddle/operators/gather.cu.h     | 14 +++++++-------
 paddle/operators/gather.h        | 18 +++++++++---------
 paddle/operators/gather_op.cu    |  4 ++--
 paddle/operators/gather_op.h     |  4 ++--
 paddle/operators/gather_test.cc  |  2 +-
 paddle/operators/scatter.cu.h    | 18 +++++++++---------
 paddle/operators/scatter.h       | 16 +++++++---------
 paddle/operators/scatter_op.cu   |  4 ++--
 paddle/operators/scatter_op.h    |  4 ++--
 paddle/operators/scatter_test.cc |  2 +-
 11 files changed, 44 insertions(+), 46 deletions(-)

diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 7d7f1ba3b1..2737104a20 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
       dim[0] = index_tensors[i].dims()[0];
       tensor_child->mutable_data<float>(dim, platform::CPUPlace());
 
-      CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
+      CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
     }
   }
 
@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
     Variable* var_child = sub_scopes[i]->FindVar(output);
     PADDLE_ENFORCE_NOT_NULL(var_child);
     auto* tensor_child = &var_child->Get<framework::LoDTensor>();
-    ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
+    ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
                          tensor_parent);
   }
 }
diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h
index 2ae11376a2..8d04ecd284 100644
--- a/paddle/operators/gather.cu.h
+++ b/paddle/operators/gather.cu.h
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
-               const Tensor* index, Tensor* output) {
+void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
 
@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();
 
   int block = 512;
diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h
index 1e39a6da27..052db49cb3 100644
--- a/paddle/operators/gather.h
+++ b/paddle/operators/gather.h
@@ -24,6 +24,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 /**
  * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
@@ -32,21 +34,19 @@ namespace operators {
  * return: output tensor
  */
 template <typename T>
-void CPUGather(const platform::DeviceContext& ctx,
-               const paddle::framework::Tensor* src,
-               const paddle::framework::Tensor* index,
-               paddle::framework::Tensor* output) {
+void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
 
-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();
 
   // slice size
diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu
index 9937be5915..92219d6a43 100644
--- a/paddle/operators/gather_op.cu
+++ b/paddle/operators/gather_op.cu
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
 
     output->mutable_data<T>(ctx.GetPlace());
 
-    GPUGather<T>(ctx.device_context(), x, index, output);
+    GPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
 
@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::GPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
 
-    GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
 
diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h
index 5bd2c36f7b..8276ed0d3d 100644
--- a/paddle/operators/gather_op.h
+++ b/paddle/operators/gather_op.h
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel {
 
     output->mutable_data<T>(ctx.GetPlace());
 
-    CPUGather<T>(ctx.device_context(), x, index, output);
+    CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
 
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::CPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
 
-    ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
 
diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc
index d8bf8dd9a4..cbd86b8796 100644
--- a/paddle/operators/gather_test.cc
+++ b/paddle/operators/gather_test.cc
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {
 
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  CPUGather<int>(ctx, src, index, output);
+  CPUGather<int>(ctx, *src, *index, output);
 
   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h
index f4a3965d94..d95436be4f 100644
--- a/paddle/operators/scatter.cu.h
+++ b/paddle/operators/scatter.cu.h
@@ -19,6 +19,8 @@
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 #define CUDA_1D_KERNEL_LOOP(i, n)                              \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * return: output tensor
  */
 template <typename T>
-void GPUScatterAssign(const platform::DeviceContext& ctx,
-                      const paddle::framework::Tensor* src,
-                      const paddle::framework::Tensor* index,
-                      paddle::framework::Tensor* output) {
+void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                      const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
 
@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();
 
   int block = 512;
diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h
index 0d174d3b5b..c1fb844ebd 100644
--- a/paddle/operators/scatter.h
+++ b/paddle/operators/scatter.h
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
  * return: output tensor
  */
 template <typename T>
-void ScatterAssign(const platform::DeviceContext& ctx,
-                   const paddle::framework::Tensor* src,
-                   const paddle::framework::Tensor* index,
-                   paddle::framework::Tensor* output) {
+void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                   const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   auto dst_dims = output->dims();
 
-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();
 
   // check src shape and dst shape should match
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
index 6d13a876f9..06f4d75944 100644
--- a/paddle/operators/scatter_op.cu
+++ b/paddle/operators/scatter_op.cu
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
 
     Out->ShareDataWith(*Ref);
 
-    GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
 
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel {
     dRef->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
 
diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h
index ac04968549..6101219006 100644
--- a/paddle/operators/scatter_op.h
+++ b/paddle/operators/scatter_op.h
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
 
@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel {
     dRef->ShareDataWith(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
 
diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc
index 321bba3dad..00dbdacbfe 100644
--- a/paddle/operators/scatter_test.cc
+++ b/paddle/operators/scatter_test.cc
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {
 
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  ScatterAssign<float>(ctx, src, index, output);
+  ScatterAssign<float>(ctx, *src, *index, output);
 
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
-- 
GitLab
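
Note on the calling convention adopted above: read-only tensor inputs are now passed by const reference and the mutable result by pointer, per the Google C++ Style Guide, so call sites change from CPUGather<T>(ctx.device_context(), x, index, output) to CPUGather<T>(ctx.device_context(), *x, *index, output). The sketch below illustrates only that convention; it is not Paddle code. Tensor and CPUGatherSketch are simplified stand-ins for paddle::framework::Tensor and the real CPUGather, and the index is a plain std::vector<int> rather than a tensor.

// Minimal, self-contained illustration of "inputs by const reference,
// output by pointer". Builds with any C++11 compiler.
#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for a 2-D row-major tensor; not paddle::framework::Tensor.
struct Tensor {
  std::vector<int> shape;   // {rows, cols}
  std::vector<float> data;  // row-major storage
};

// Gather the rows of `src` selected by `index` into `output`.
// Read-only inputs are const references; the output buffer is a pointer.
void CPUGatherSketch(const Tensor& src, const std::vector<int>& index,
                     Tensor* output) {
  assert(src.shape.size() == 2);
  const int cols = src.shape[1];
  output->shape = {static_cast<int>(index.size()), cols};
  output->data.resize(index.size() * cols);
  for (size_t i = 0; i < index.size(); ++i) {
    for (int j = 0; j < cols; ++j) {
      output->data[i * cols + j] = src.data[index[i] * cols + j];
    }
  }
}

int main() {
  Tensor src{{4, 2}, {0, 1, 2, 3, 4, 5, 6, 7}};
  Tensor out;
  // The caller dereferences its pointers for the inputs and passes &out,
  // mirroring the *x, *index, output pattern in the patch.
  CPUGatherSketch(src, {3, 1}, &out);
  for (float v : out.data) std::printf("%g ", v);  // prints: 6 7 2 3
  std::printf("\n");
  return 0;
}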