diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index b910bee836ed488aeb34f28d0503b5efba396583..10922892ca7ff90646f06ba74ee882535a889df3 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,6 +41,11 @@ function(op_library TARGET) endif() endfunction() +op_library(gather SRCS gather_func.cc) +cc_test(gather_test SRCS gather_test.cc DEPS gather) + +op_library(scatter SRCS scatter_func.cc) + op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) diff --git a/paddle/operators/gather_func.cc b/paddle/operators/gather_func.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6b2331f32a54d978fd42e6a3bd6ce6ab0fd8098 --- /dev/null +++ b/paddle/operators/gather_func.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather_func.h" +#include +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" diff --git a/paddle/operators/gather_func.h b/paddle/operators/gather_func.h index 5975675cbbdcc4ccbe15f9543da81e2f19cf0051..5adc1e6b179139729602a878337192867d8ff2c9 100644 --- a/paddle/operators/gather_func.h +++ b/paddle/operators/gather_func.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -13,51 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include + #include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" -/** - * Return a new tensor from source tensor, gathered according to index - * input[src]: type-T source Tensor - * input[index]: type-int index Tensor (1-D) - * return: output tensor - */ -template -Tensor* Gather(Tensor* src, Tensor* index) { - // check index of shape 1-D - PADDLE_ENFORCE(index->dims().size() == 1); - int index_size = index->dims()[0]; - - // Source shape - auto src_dims = src->dims(); - DDim output_dims(dims_src); - // Create a tensor of shape [index_size, dim_src[1:]] - output_dims[0] = index_size; - - Tensor* New_tensor; - float* output = nullptr; - - /* slice size */ - int slice_size = 1; - for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; +using paddle::framework::Tensor; +using paddle::framework::DDim; - /* Gathering */ - if (place == CPUPlace()) { - // init for CPU - output = New_tensor.mutable_data(output_dims, CPUPlace()); - CPUGather( - src->data(), index->data(), slice_size, new_tensor->mutable_data()); - } else { // GPU - // init for GPU - output = New_tensor.mutable_data(output_dims, GPUPlace()); - /* how to specialize device??*/ - GPUGather( - d, src->data(), index->data(), slice_size, new_tensor->mutable_data()); - } - return New_tensor; -} +namespace paddle { +namespace operators { /* Implementation of CPU copy */ template @@ -70,48 +37,61 @@ void CPUGather(const T* params, for (size_t i = 0; i < index_size; ++i) { int index_ = indices[i]; - /* copy src[index_] to output[i] */ - memcpy( - output + i * slice_bytes, params + index_ * slice_bytes, slice_bytes); + // copy src[index_] to output[i] + memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); } } /* Implementation of GPU copy: - I suppose the GPUDevice& d, contains gpu_id and thread_id - d = cuda_stream(gpu_id_, stream_id_); + I suppose the GPUDevice& d, contains gpu_id and thread_id + d = cuda_stream(gpu_id_, stream_id_); */ template -void GPUGather(const GPUDevice& d, - const T* src, +void GPUGather(const T* src, const int* index, const int slice_size, const int index_size, - T* output) { - int block_count = slice_size * index_size; - int thread_per_block = 1024; - - GatherOpKernel<<>>( - src, index, output, slice_size, indices_size, slice_size, out_size); -} + T* output); +/** + * Return a new tensor from source tensor, gathered according to index + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ template -__global__ void GatherOpKernel(const T* params, - const int* indices, - T* out, - int64 indices_size, - int64 slice_size, - int64 out_size) { - /* I suppose we have the following macro, - which I strongly suggest that we should put in cuda: - #define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - */ - CUDA_1D_KERNEL_LOOP(i, out_size) { - int indices_i = i / slice_size; - int slice_i = i - indices_i * slice_size; // offset inside the slice - int gather_i = indices[indices_i]; - int params_i = gather_i * slice_size + slice_i; - out[i] = *(params + params_i); +void Gather(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + DDim output_dims(src_dims); + output_dims[0] = index_size; + + // slice size + int slice_size = 1; + for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + // Gathering + if (platform::is_cpu_place(place)) { + CPUGather(src->data(), + index->data(), + slice_size, + index_size, + output->data()); + } else { + // init for GPU + // output_arr = output->mutable_data(output_dims, platform::GPUPlace()); + // how to specialize device?? + // GPUGather( + // d, src->data(), index->data(), slice_size, + // new_tensor->mutable_data()); } } + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f220b133bc1cac6a2081f2801da69bb578bf842 --- /dev/null +++ b/paddle/operators/gather_test.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/operators/gather_func.h" +#include "paddle/platform/place.h" + +#include +#include +#include + +TEST(_abc_, GatherData) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + Tensor* src = new Tensor(); + Tensor* index = new Tensor(); + Tensor* output = new Tensor(); + // src.Resize(make_ddim({3, 4})); + + int* p_src = nullptr; + int* p_index = nullptr; + p_src = src->mutable_data(make_ddim({3, 4}), CPUPlace()); + p_index = index->mutable_data(make_ddim({2}), CPUPlace()); + + for (size_t i = 0; i < 12; ++i) p_src[i] = i; + p_index[0] = 1; + p_index[1] = 0; + + // gather + int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); + + Gather(CPUPlace(), src, index, output); + + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); + for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); +}