提交 d4e4cebf 编写于 作者: Z zchen0211

fix all coding-style problems

上级 6159f5db
......@@ -20,13 +20,10 @@ limitations under the License. */
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
using paddle::framework::Tensor;
using paddle::framework::DDim;
namespace paddle {
namespace operators {
/* Implementation of CPU copy */
// Implementation of CPU copy
template <typename T>
void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) {
......@@ -34,15 +31,11 @@ void CPUGather(const T* params, const int* indices, const int slice_size,
for (size_t i = 0; i < index_size; ++i) {
int index_ = indices[i];
// copy src[index_] to output[i]
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
}
}
/* Implementation of GPU copy:
I suppose the GPUDevice& d, contains gpu_id and thread_id
d = cuda_stream(gpu_id_, stream_id_);
*/
// Implementation of GPU copy:
template <typename T>
void GPUGather(const T* src, const int* index, const int slice_size,
const int index_size, T* output);
......@@ -62,7 +55,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
int index_size = index->dims()[0];
auto src_dims = src->dims();
DDim output_dims(src_dims);
paddle::framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
......@@ -73,13 +66,6 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
} else {
// init for GPU
// output_arr = output->mutable_data<T>(output_dims, platform::GPUPlace());
// how to specialize device??
// GPUGather(
// d, src->data(), index->data(), slice_size,
// new_tensor->mutable_data());
}
}
......
......@@ -29,7 +29,6 @@ TEST(Gather, GatherData) {
Tensor* src = new Tensor();
Tensor* index = new Tensor();
Tensor* output = new Tensor();
// src.Resize(make_ddim({3, 4}));
int* p_src = nullptr;
int* p_index = nullptr;
......@@ -40,7 +39,6 @@ TEST(Gather, GatherData) {
p_index[0] = 1;
p_index[1] = 0;
// gather
int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
Gather<int>(CPUPlace(), src, index, output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册