From 78808b20911dd95e1a49495c99d814b59e3290c9 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Thu, 28 Sep 2017 17:27:37 -0700 Subject: [PATCH] 1 api --- paddle/operators/cond_op.cc | 4 ++-- paddle/operators/gather.cu.h | 30 ++++++++++---------------- paddle/operators/gather.h | 38 +++++++++++++-------------------- paddle/operators/gather_op.cu | 4 ++-- paddle/operators/gather_op.h | 2 +- paddle/operators/gather_test.cc | 2 +- paddle/operators/scatter.cu.h | 36 ++++++++++++------------------- paddle/operators/scatter.h | 20 ++++++----------- paddle/operators/scatter_op.cu | 4 ++-- paddle/operators/scatter_op.h | 2 +- 10 files changed, 55 insertions(+), 87 deletions(-) diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 157656786ab..983b5142b1d 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -169,8 +169,8 @@ void CondOp::Run(const Scope& scope, tensor_child->Resize(dim); tensor_child->mutable_data(dim, platform::CPUPlace()); - CPUTGather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], - tensor_child); + CPUGather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], + tensor_child); } } diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h index c96071e2955..b400c104407 100644 --- a/paddle/operators/gather.cu.h +++ b/paddle/operators/gather.cu.h @@ -38,19 +38,6 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output, } } -// Implementation of GPU copy: -template -struct GPUGather { - void operator()(const T* src, const int* index, const int slice_size, - const int index_size, T* output) { - int block = 512; - int n = slice_size * index_size; - int grid = (n + block - 1) / block; - GatherCUDAKernel<<>>(src, index, output, index_size, - slice_size); - } -}; - /** * A thin wrapper on gpu tensor * Return a new tensor from source tensor, gathered according to index @@ -59,8 +46,8 @@ struct GPUGather { * return: output tensor */ template -void GPUTGather(const Place& place, const Tensor* src, const Tensor* index, - Tensor* output) { +void GPUGather(const Place& place, const Tensor* src, const Tensor* index, + Tensor* output) { PADDLE_ENFORCE(platform::is_gpu_place(place)); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); @@ -74,10 +61,15 @@ void GPUTGather(const Place& place, const Tensor* src, const Tensor* index, int slice_size = 1; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - // Gathering - GPUGather gather_functor; - gather_functor(src->data(), index->data(), slice_size, index_size, - output->data()); + const T* p_src = src->data(); + const int* p_index = index->data(); + T* p_output = output->data(); + + int block = 512; + int n = slice_size * index_size; + int grid = (n + block - 1) / block; + GatherCUDAKernel<<>>(p_src, p_index, p_output, index_size, + slice_size); } } // namespace operators diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index a3db17bd3dd..cb635f68255 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -24,32 +24,18 @@ limitations under the License. */ namespace paddle { namespace operators { -// Implementation of CPU copy -template -struct CPUGather { - void operator()(const T* src, const int* indices, const int slice_size, - const int index_size, T* output) { - const size_t slice_bytes = slice_size * sizeof(T); - - for (int i = 0; i < index_size; ++i) { - int index_ = indices[i]; - memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); - } - } -}; - /** - * A thin wrapper on cpu tensor + * A thin wrapper for gathering on cpu tensor * Return a new tensor from source tensor, gathered according to index * input[src]: type-T source Tensor * input[index]: type-int index Tensor (1-D) * return: output tensor */ template -void CPUTGather(const platform::Place& place, - const paddle::framework::Tensor* src, - const paddle::framework::Tensor* index, - paddle::framework::Tensor* output) { +void CPUGather(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { PADDLE_ENFORCE(platform::is_cpu_place(place)); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); @@ -59,14 +45,20 @@ void CPUTGather(const platform::Place& place, framework::DDim output_dims(src_dims); output_dims[0] = index_size; + const T* p_src = src->data(); + const int* p_index = index->data(); + T* p_output = output->data(); + // slice size int slice_size = 1; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - // Gathering - CPUGather gather_functor; - gather_functor(src->data(), index->data(), slice_size, index_size, - output->data()); + const size_t slice_bytes = slice_size * sizeof(T); + + for (int i = 0; i < index_size; ++i) { + int index_ = p_index[i]; + memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes); + } } } // namespace operators diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu index f7533cdd644..06004614b2c 100644 --- a/paddle/operators/gather_op.cu +++ b/paddle/operators/gather_op.cu @@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); - GPUTGather(ctx.GetPlace(), x, index, output); + GPUGather(ctx.GetPlace(), x, index, output); } }; @@ -53,7 +53,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel { auto place = ctx.GetEigenDevice(); dxt.device(place) = dxt.constant(static_cast(0)); - GPUTScatter(ctx.GetPlace(), dO, Index, dX); + GPUScatterAssign(ctx.GetPlace(), dO, Index, dX); } }; diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index b80a4ab3705..fb065b8da7d 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); - CPUTGather(ctx.GetPlace(), x, index, output); + CPUGather(ctx.GetPlace(), x, index, output); } }; diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index ea06ae28472..3c1d06ccd10 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -41,7 +41,7 @@ TEST(Gather, GatherData) { int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); - CPUTGather(CPUPlace(), src, index, output); + CPUGather(CPUPlace(), src, index, output); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h index 82e50403056..add4791a793 100644 --- a/paddle/operators/scatter.cu.h +++ b/paddle/operators/scatter.cu.h @@ -36,20 +36,6 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices, } } -// Implementation of GPU copy: -template -struct GPUScatterAssign { - void operator()(const T* src, const int* index, const int slice_size, - const int index_size, T* output) { - int block = 512; - int n = slice_size * index_size; - int grid = (n + block - 1) / block; - // printf("grid, block: %d %d\n", grid, block); - ScatterCUDAKernel<<>>(src, index, output, index_size, - slice_size); - } -}; - /** * A thin wrapper on gpu tensor * Return a new updated tensor from source tensor, scatter-assigned according to @@ -59,10 +45,10 @@ struct GPUScatterAssign { * return: output tensor */ template -void GPUTScatter(const platform::Place& place, - const paddle::framework::Tensor* src, - const paddle::framework::Tensor* index, - paddle::framework::Tensor* output) { +void GPUScatterAssign(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { PADDLE_ENFORCE(platform::is_gpu_place(place)); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); @@ -76,10 +62,16 @@ void GPUTScatter(const platform::Place& place, int slice_size = 1; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - // Scatter Assign - GPUScatterAssign scatter_functor; - scatter_functor(src->data(), index->data(), slice_size, index_size, - output->data()); + const T* p_src = src->data(); + const int* p_index = index->data(); + T* p_output = output->data(); + + int block = 512; + int n = slice_size * index_size; + int grid = (n + block - 1) / block; + + ScatterCUDAKernel<<>>(p_src, p_index, p_output, index_size, + slice_size); } } // namespace operators diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h index 670204b4dd4..f895f22e281 100644 --- a/paddle/operators/scatter.h +++ b/paddle/operators/scatter.h @@ -25,19 +25,6 @@ namespace operators { using Tensor = framework::Tensor; -// Implementation of CPU copy -template -void CPUScatterAssign(const T* src, const int* index, const int slice_size, - const int index_size, T* output) { - // paddle::framework::DDim output_dims = output->dims(); - const size_t slice_bytes = slice_size * sizeof(T); - - for (int i = 0; i < index_size; ++i) { - int index_ = index[i]; - memcpy(output + index_ * slice_size, src + i * slice_size, slice_bytes); - } -} - /** * Return a updated tensor from source tensor, scattered according to index: * dst[i] = src[index[i]] @@ -70,7 +57,12 @@ void ScatterAssign(const platform::Place& place, size_t slice_size = 1; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - CPUScatterAssign(p_src, p_index, slice_size, index_size, p_output); + const size_t slice_bytes = slice_size * sizeof(T); + + for (int i = 0; i < index_size; ++i) { + int index_ = p_index[i]; + memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes); + } } } // namespace operators diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu index 89d23945e0c..831eabdae4f 100644 --- a/paddle/operators/scatter_op.cu +++ b/paddle/operators/scatter_op.cu @@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel { Out->ShareDataWith(*Ref); - GPUTScatter(ctx.GetPlace(), Updates, Index, Out); + GPUScatterAssign(ctx.GetPlace(), Updates, Index, Out); } }; @@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel { dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates = dO[Index] - GPUTGather(ctx.GetPlace(), dOut, Index, dUpdates); + GPUGather(ctx.GetPlace(), dOut, Index, dUpdates); } }; diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index 74b2718f433..771a1f2ddb7 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Index] - CPUTGather(ctx.GetPlace(), dOut, Index, dUpdates); + CPUGather(ctx.GetPlace(), dOut, Index, dUpdates); } }; -- GitLab