diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 10922892ca7ff90646f06ba74ee882535a889df3..a2284fc8f0f15ce742803e13efedbd766023582b 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,10 +41,7 @@ function(op_library TARGET) endif() endfunction() -op_library(gather SRCS gather_func.cc) -cc_test(gather_test SRCS gather_test.cc DEPS gather) - -op_library(scatter SRCS scatter_func.cc) +cc_test(gather_test SRCS gather_test.cc DEPS tensor) op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) diff --git a/paddle/operators/gather_func.h b/paddle/operators/gather.h similarity index 100% rename from paddle/operators/gather_func.h rename to paddle/operators/gather.h diff --git a/paddle/operators/gather_func.cc b/paddle/operators/gather_func.cc deleted file mode 100644 index a6b2331f32a54d978fd42e6a3bd6ce6ab0fd8098..0000000000000000000000000000000000000000 --- a/paddle/operators/gather_func.cc +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/operators/gather_func.h" -#include -#include "paddle/framework/ddim.h" -#include "paddle/framework/tensor.h" -#include "paddle/platform/place.h" diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 6f220b133bc1cac6a2081f2801da69bb578bf842..5d84b7b5f30e0d838896bb16b39e26d24bd916c1 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/operators/gather.h" #include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" -#include "paddle/operators/gather_func.h" #include "paddle/platform/place.h" #include #include #include -TEST(_abc_, GatherData) { +TEST(Gather, GatherData) { using namespace paddle::framework; using namespace paddle::platform; using namespace paddle::operators; diff --git a/paddle/operators/scatter_func.h b/paddle/operators/scatter_func.h deleted file mode 100644 index 53b260170fb42da6e759712475827a17b197bd69..0000000000000000000000000000000000000000 --- a/paddle/operators/scatter_func.h +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/framework/ddim.h" -#include "paddle/framework/tensor.h" -#include "paddle/platform/place.h" - -/** - * Return a updated tensor from source tensor, scattered according to index: - * dst[i] += src[index[i]] - * input[src]: type-T source Tensor - * input[index]: type-int index Tensor (1-D) - * return: output tensor - */ -template -void ScatterUpdate(Tensor* src, Tensor* dst, Tensor* index) { - // Source shape - auto src_dims = src->dims(); - auto dst_dims = dst->dims(); - DDim output_dims(dims_src); - - // check src shape and dst shape should match - for (size_t i = 1; i < src_dims.size(); i++) - PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); - - int index_size = index->dims()[0]; - - /* slice size */ - int slice_size = 1; - for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - - if (place == CPUPlace()) { - // init - output = new_tensor.mutable_data(output_dims, CPUPlace()); - CPUScatterUpdate( - src->data(), index->data(), slice_size, new_tensor->mutable_data()); - - } else { // GPU - // init - output = new_tensor.mutable_data(output_dims, GPUPlace()); - /* how to specialize device??*/ - GPUScatterUpdate( - d, src->data(), index->data(), slice_size, new_tensor->mutable_data()); - } -} - -/* Implementation of CPU copy */ -template -void CPUScatterUpdate(const T* src, - const int* index, - const int slice_size, - const int index_size, - T* output) { - // const size_t slice_bytes = slice_size * sizeof(T); - - for (size_t i = 0; i < index_size; ++i) { - int index_ = index[i]; - math::vAdd(slice_size, - src + index_ * slice_bytes, - output + i * slice_bytes, - output + i * slice_bytes); - } -} - -/* Implementation of GPU scatter: - I suppose the GPUDevice& d, contains gpu_id and thread_id - d = cuda_stream(gpu_id_, stream_id_); -*/ -template -void GPUScatterUpdate(const GPUDevice& d, - const T* src, - const int* index, - const int slice_size, - const int index_size, - T* output) { - int block_count = slice_size * index_size; - int thread_per_block = 1024; - - ScatterOpKernel<<>>( - src, index, output, slice_size, indices_size, slice_size, out_size); -} - -template -__global__ void ScatterOpKernel(const T* params, - const int* indices, - T* out, - int64 indices_size, - int64 slice_size, - int64 out_size) { - /* I suppose we have the following macro, - which I strongly suggest that we should put in cuda: - #define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - */ - CUDA_1D_KERNEL_LOOP(i, out_size) { - int indices_i = i / slice_size; - int slice_i = i - indices_i * slice_size; // offset inside the slice - int scatter_i = indices[indices_i]; - int params_i = scatter_i * slice_size + slice_i; - out[i] += *(params + params_i); - } -}