提交 9029a9d9 编写于 作者: Y Yu Yang 提交者: Yi Wang

Fix constructor bug in mixed_vector (#8364)

* Fix constructor bug in mixed_vector

* Fix warnings

* Clean code

* Extract for-loop init. Make nvcc happy
上级 274f4e94
...@@ -37,9 +37,8 @@ class Vector { ...@@ -37,9 +37,8 @@ class Vector {
// Fill vector with value. The vector size is `count`. // Fill vector with value. The vector size is `count`.
explicit Vector(size_t count, const T& value = T()) { explicit Vector(size_t count, const T& value = T()) {
if (count == 0) { InitEmpty();
InitEmpty(); if (count != 0) {
} else {
resize(count); resize(count);
T* ptr = begin(); T* ptr = begin();
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
...@@ -122,6 +121,10 @@ class Vector { ...@@ -122,6 +121,10 @@ class Vector {
const T* begin() const { return &this->operator[](0); } const T* begin() const { return &this->operator[](0); }
const T* end() const { return &this->operator[](size()); } const T* end() const { return &this->operator[](size()); }
const T* cbegin() const { return begin(); }
const T* cend() const { return end(); }
const T& back() const { const T& back() const {
auto it = end(); auto it = end();
--it; --it;
...@@ -244,7 +247,9 @@ class Vector { ...@@ -244,7 +247,9 @@ class Vector {
bool operator==(const Vector<T>& other) const { bool operator==(const Vector<T>& other) const {
if (size() != other.size()) return false; if (size() != other.size()) return false;
for (auto it1 = begin(), it2 = other.begin(); it1 < end(); ++it1, ++it2) { auto it1 = cbegin();
auto it2 = other.cbegin();
for (; it1 < cend(); ++it1, ++it2) {
if (*it1 != *it2) { if (*it1 != *it2) {
return false; return false;
} }
......
...@@ -26,10 +26,10 @@ TEST(mixed_vector, CPU_VECTOR) { ...@@ -26,10 +26,10 @@ TEST(mixed_vector, CPU_VECTOR) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
tmp.push_back(i); tmp.push_back(i);
} }
ASSERT_EQ(tmp.size(), 10); ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2; vec<int> tmp2;
tmp2 = tmp; tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10); ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i); ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]); ASSERT_EQ(tmp2[i], tmp[i]);
...@@ -58,7 +58,7 @@ TEST(mixed_vector, GPU_VECTOR) { ...@@ -58,7 +58,7 @@ TEST(mixed_vector, GPU_VECTOR) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
tmp.push_back(i); tmp.push_back(i);
} }
ASSERT_EQ(tmp.size(), 10); ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu(0); paddle::platform::CUDAPlace gpu(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu)); multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu));
...@@ -79,7 +79,7 @@ TEST(mixed_vector, MultiGPU) { ...@@ -79,7 +79,7 @@ TEST(mixed_vector, MultiGPU) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
tmp.push_back(i); tmp.push_back(i);
} }
ASSERT_EQ(tmp.size(), 10); ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu0(0); paddle::platform::CUDAPlace gpu0(0);
paddle::platform::SetDeviceId(0); paddle::platform::SetDeviceId(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0)); multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
...@@ -91,3 +91,10 @@ TEST(mixed_vector, MultiGPU) { ...@@ -91,3 +91,10 @@ TEST(mixed_vector, MultiGPU) {
ASSERT_EQ(tmp[i], i * 100); ASSERT_EQ(tmp[i], i * 100);
} }
} }
TEST(mixed_vector, InitWithCount) {
paddle::framework::Vector<int> vec(10, 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(vec[i], 10);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册