You need to sign in or sign up before continuing.
未验证 提交 2d21aa76 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #12331 from jacquesqiao/fix-mixed-tensor

fix mixed tensor compile and add cpu unit test
......@@ -22,7 +22,12 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place memory device_context tensor)
if(WITH_GPU)
nv_test(mixed_vector_test SRCS mixed_vector_test.cc mixed_vector_test.cu DEPS place memory device_context tensor)
else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
......
......@@ -16,6 +16,7 @@
#include <algorithm>
#include <initializer_list>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
......@@ -386,13 +387,14 @@ template <typename T>
class CPUVector : public std::vector<T, std::allocator<T>> {
public:
CPUVector() : std::vector<T>() {}
CPUVector(size_t count, const T &value = T())
CPUVector(size_t count, const T &value = T()) // NOLINT
: std::vector<T>(count, value) {}
CPUVector(std::initializer_list<T> init) : std::vector<T>(init) {}
CPUVector(const std::vector<T> &other) : std::vector<T>(other) {}
explicit CPUVector(const CPUVector<T> &other) : std::vector<T>(other) {}
CPUVector(const std::vector<T> &other) : std::vector<T>(other) {} // NOLINT
CPUVector(const CPUVector<T> &other) : std::vector<T>(other) {}
CPUVector(CPUVector<T> &&other) : std::vector<T>(std::move(other)) {}
CPUVector(std::vector<T> &&other) : std::vector<T>(std::move(other)) {}
CPUVector(std::vector<T> &&other) // NOLINT
: std::vector<T>(std::move(other)) {}
CPUVector &operator=(const CPUVector &other) {
this->assign(other.begin(), other.end());
return *this;
......@@ -410,8 +412,6 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
return os;
}
void resize(size_t size) { this->resize(size); }
T &operator[](size_t id) { return this->at(id); }
const T &operator[](size_t id) const { return this->at(id); }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
template <typename T>
using vec = paddle::framework::Vector<T>;
TEST(mixed_vector, CPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
}
int cnt = 0;
for (auto& t : tmp2) {
ASSERT_EQ(t, cnt);
++cnt;
}
}
TEST(mixed_vector, InitWithCount) {
paddle::framework::Vector<int> vec(10, 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(vec[i], 10);
}
}
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
VLOG(3) << v;
}
}
TEST(mixed_vector, Reserve) {
paddle::framework::Vector<int> vec;
vec.reserve(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
TEST(mixed_vector, Resize) {
paddle::framework::Vector<int> vec;
vec.resize(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
......@@ -11,7 +11,9 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
......@@ -21,26 +23,6 @@
template <typename T>
using vec = paddle::framework::Vector<T>;
TEST(mixed_vector, CPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
vec<int> tmp2;
tmp2 = tmp;
ASSERT_EQ(tmp2.size(), 10UL);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp2[i], i);
ASSERT_EQ(tmp2[i], tmp[i]);
}
int cnt = 0;
for (auto& t : tmp2) {
ASSERT_EQ(t, cnt);
++cnt;
}
}
static __global__ void multiply_10(int* ptr) {
for (int i = 0; i < 10; ++i) {
ptr[i] *= 10;
......@@ -91,24 +73,3 @@ TEST(mixed_vector, MultiGPU) {
ASSERT_EQ(tmp[i], i * 100);
}
}
TEST(mixed_vector, InitWithCount) {
paddle::framework::Vector<int> vec(10, 10);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(vec[i], 10);
}
}
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
}
}
TEST(mixed_vector, Reserve) {
paddle::framework::Vector<int> vec;
vec.reserve(1);
vec.push_back(0);
vec.push_back(0);
vec.push_back(0);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册