提交 6159f5db 编写于 作者: Z zchen0211

code style fix

上级 689d5ee1
......@@ -28,11 +28,8 @@ namespace operators {
/* Implementation of CPU copy */
template <typename T>
void CPUGather(const T* params,
const int* indices,
const int slice_size,
const int index_size,
T* output) {
void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) {
......@@ -47,11 +44,8 @@ void CPUGather(const T* params,
d = cuda_stream(gpu_id_, stream_id_);
*/
template <typename T>
void GPUGather(const T* src,
const int* index,
const int slice_size,
const int index_size,
T* output);
void GPUGather(const T* src, const int* index, const int slice_size,
const int index_size, T* output);
/**
* Return a new tensor from source tensor, gathered according to index
......@@ -60,8 +54,7 @@ void GPUGather(const T* src,
* return: output tensor
*/
template <typename T>
void Gather(const platform::Place& place,
const paddle::framework::Tensor* src,
void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
// check index of shape 1-D
......@@ -78,10 +71,7 @@ void Gather(const platform::Place& place,
// Gathering
if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(),
index->data<int>(),
slice_size,
index_size,
CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
} else {
// init for GPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册