diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index 8b021565450a7f2ac80a32fd4fc8cb11fd29e54a..0c73717d38aca9f3430e66cafc3ecccdd2eec776 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -20,13 +20,10 @@ limitations under the License. */ #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" -using paddle::framework::Tensor; -using paddle::framework::DDim; - namespace paddle { namespace operators { -/* Implementation of CPU copy */ +// Implementation of CPU copy template void CPUGather(const T* params, const int* indices, const int slice_size, const int index_size, T* output) { @@ -34,15 +31,11 @@ void CPUGather(const T* params, const int* indices, const int slice_size, for (size_t i = 0; i < index_size; ++i) { int index_ = indices[i]; - // copy src[index_] to output[i] memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); } } -/* Implementation of GPU copy: - I suppose the GPUDevice& d, contains gpu_id and thread_id - d = cuda_stream(gpu_id_, stream_id_); -*/ +// Implementation of GPU copy: template void GPUGather(const T* src, const int* index, const int slice_size, const int index_size, T* output); @@ -62,7 +55,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, int index_size = index->dims()[0]; auto src_dims = src->dims(); - DDim output_dims(src_dims); + paddle::framework::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size @@ -73,13 +66,6 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, if (platform::is_cpu_place(place)) { CPUGather(src->data(), index->data(), slice_size, index_size, output->data()); - } else { - // init for GPU - // output_arr = output->mutable_data(output_dims, platform::GPUPlace()); - // how to specialize device?? - // GPUGather( - // d, src->data(), index->data(), slice_size, - // new_tensor->mutable_data()); } } diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 5d84b7b5f30e0d838896bb16b39e26d24bd916c1..5de748ec461e4b1a34b75b57c9cd7d5bc9326059 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -29,7 +29,6 @@ TEST(Gather, GatherData) { Tensor* src = new Tensor(); Tensor* index = new Tensor(); Tensor* output = new Tensor(); - // src.Resize(make_ddim({3, 4})); int* p_src = nullptr; int* p_index = nullptr; @@ -40,7 +39,6 @@ TEST(Gather, GatherData) { p_index[0] = 1; p_index[1] = 0; - // gather int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); Gather(CPUPlace(), src, index, output);