From 2b35fca18f66e5f92315e369a687a5e908aedf1e Mon Sep 17 00:00:00 2001 From: Zhuoyuan <chenzhuoyuan07@gmail.com> Date: Wed, 2 Aug 2017 22:34:58 -0700 Subject: [PATCH] gather modify --- paddle/operators/gather_func.h | 71 ++++++++++++++++------------------ 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/paddle/operators/gather_func.h b/paddle/operators/gather_func.h index 09e751ce17a..e255bd7d151 100644 --- a/paddle/operators/gather_func.h +++ b/paddle/operators/gather_func.h @@ -21,44 +21,41 @@ limitations under the License. */ /** * Return a new tensor from source tensor, gathered according to index * input[src]: type-T source Tensor - * input[Index]: type-int index Tensor (1-D) + * input[index]: type-int index Tensor (1-D) * return: output tensor */ -template <typename place, typename T> -Tensor* Gather_func(Tensor* Src, Tensor* Index) { - // assert index is an int-type tensor? - // assert(Index->istype(int)); +template <typename Place, typename T> +Tensor* Gather(Tensor* src, Tensor* index) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size()==1); + int index_size = index->dims()[0]; - // check index of shape 1-D - assert(Index->dims().size()==1); - int index_size = Index->dims()[0]; + // Source shape + auto src_dims = src->dims(); + DDim output_dims(dims_src); + // Create a tensor of shape [index_size, dim_src[1:]] + output_dims[0] = index_size; - // Source shape - auto src_dims = Src->dims(); - DDim output_dims(dims_src); - // Create a tensor of shape [index_size, dim_src[1:]] - output_dims[0] = index_size; + Tensor* New_tensor; + float* output = nullptr; - Tensor* New_tensor; - float* output = nullptr; + /* slice size */ + int slice_size = 1; + for(unsigned int i = 0; i < src_dims.size(); ++i) + slice_size *= src_dims[i]; - /* slice size */ - int slice_size = 1; - for(unsigned int i = 0; i < src_dims.size(); ++i) - slice_size *= src_dims[i]; - - /* Gathering */ - if (place == CPUPlace()) { - // init for CPU - output = New_tensor.mutable_data<T>(output_dims, CPUPlace()); - CPUGather(Src->data(), Index->data(), slice_size, new_tensor->mutable_data()); - } else { // GPU - // init for GPU - output = New_tensor.mutable_data<T>(output_dims, GPUPlace()); - /* how to specialize device??*/ - GPUGather(d, Src->data(), Index->data(), slice_size, new_tensor->mutable_data()); - } - return New_tensor; + /* Gathering */ + if (place == CPUPlace()) { + // init for CPU + output = New_tensor.mutable_data<T>(output_dims, CPUPlace()); + CPUGather(src->data(), index->data(), slice_size, new_tensor->mutable_data()); + } else { // GPU + // init for GPU + output = New_tensor.mutable_data<T>(output_dims, GPUPlace()); + /* how to specialize device??*/ + GPUGather(d, src->data(), index->data(), slice_size, new_tensor->mutable_data()); + } + return New_tensor; } /* Implementation of CPU copy */ @@ -82,15 +79,15 @@ void CPUGather(const T* params, const int* indices, */ template<typename T> void GPUGather(const GPUDevice& d, - const T* src, const int* Index, + const T* src, const int* index, const int slice_size, const int index_size, T* output) { - int block_count = slice_size * index_size; - int thread_per_block = 1024; + int block_count = slice_size * index_size; + int thread_per_block = 1024; - GatherOpKernel<T> + GatherOpKernel<T> <<<block_count, thread_per_block, 0, d.stream()>>>( - src, Index, output, slice_size, + src, index, output, slice_size, indices_size, slice_size, out_size); } -- GitLab