提交 6159f5db 编写于 作者: Z zchen0211

code style fix

上级 689d5ee1
...@@ -28,11 +28,8 @@ namespace operators { ...@@ -28,11 +28,8 @@ namespace operators {
/* Implementation of CPU copy */ /* Implementation of CPU copy */
template <typename T> template <typename T>
void CPUGather(const T* params, void CPUGather(const T* params, const int* indices, const int slice_size,
const int* indices, const int index_size, T* output) {
const int slice_size,
const int index_size,
T* output) {
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) { for (size_t i = 0; i < index_size; ++i) {
...@@ -47,11 +44,8 @@ void CPUGather(const T* params, ...@@ -47,11 +44,8 @@ void CPUGather(const T* params,
d = cuda_stream(gpu_id_, stream_id_); d = cuda_stream(gpu_id_, stream_id_);
*/ */
template <typename T> template <typename T>
void GPUGather(const T* src, void GPUGather(const T* src, const int* index, const int slice_size,
const int* index, const int index_size, T* output);
const int slice_size,
const int index_size,
T* output);
/** /**
* Return a new tensor from source tensor, gathered according to index * Return a new tensor from source tensor, gathered according to index
...@@ -60,8 +54,7 @@ void GPUGather(const T* src, ...@@ -60,8 +54,7 @@ void GPUGather(const T* src,
* return: output tensor * return: output tensor
*/ */
template <typename T> template <typename T>
void Gather(const platform::Place& place, void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index, const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) { paddle::framework::Tensor* output) {
// check index of shape 1-D // check index of shape 1-D
...@@ -78,10 +71,7 @@ void Gather(const platform::Place& place, ...@@ -78,10 +71,7 @@ void Gather(const platform::Place& place,
// Gathering // Gathering
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(), CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
index->data<int>(),
slice_size,
index_size,
output->data<T>()); output->data<T>());
} else { } else {
// init for GPU // init for GPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册