提交 03d0040c 编写于 作者: Z zchen0211

gather warning fixed

上级 39f14f1d
......@@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) {
for (int i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
}
......@@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
// slice size
int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering
if (platform::is_cpu_place(place)) {
......
......@@ -35,7 +35,7 @@ TEST(Gather, GatherData) {
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());
for (size_t i = 0; i < 12; ++i) p_src[i] = i;
for (int i = 0; i < 12; ++i) p_src[i] = i;
p_index[0] = 1;
p_index[1] = 0;
......@@ -43,6 +43,6 @@ TEST(Gather, GatherData) {
Gather<int>(CPUPlace(), src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册