提交 68adb954 编写于 作者: F fengjiayi

enbale tensor memory test

上级 9b8451cc
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) cc_test(tensor_test SRCS tensor_test.cc DEPS ddim paddle_memory)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc) cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc) cc_test(enforce_test SRCS enforce_test.cc)
......
...@@ -29,8 +29,6 @@ class Tensor { ...@@ -29,8 +29,6 @@ class Tensor {
public: public:
Tensor() : numel_(0), offset_(0) {} Tensor() : numel_(0), offset_(0) {}
Tensor& operator=(const Tensor& src) = delete;
template <typename T> template <typename T>
const T* data() const { const T* data() const {
CheckDims<T>(); CheckDims<T>();
...@@ -39,13 +37,13 @@ class Tensor { ...@@ -39,13 +37,13 @@ class Tensor {
} }
template <typename T> template <typename T>
T* mutable_data(DDim dims, paddle::platform::Place place) { T* mutable_data(DDim dims, platform::Place place) {
set_dims(dims); set_dims(dims);
return mutable_data<T>(place); return mutable_data<T>(place);
} }
template <typename T> template <typename T>
T* mutable_data(paddle::platform::Place place) { T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0, PADDLE_ENFORCE(numel_ > 0,
"Tensor::numel_ must be larger than zero to call " "Tensor::numel_ must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."); "Tensor::mutable_data. Call Tensor::set_dim first.");
...@@ -53,7 +51,18 @@ class Tensor { ...@@ -53,7 +51,18 @@ class Tensor {
!(holder_->place() == !(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */ place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) { || holder_->size() < numel_ * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T))); switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
offset_ = 0; offset_ = 0;
} }
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
...@@ -69,7 +78,7 @@ class Tensor { ...@@ -69,7 +78,7 @@ class Tensor {
} }
template <typename T> template <typename T>
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) { void CopyFrom(const Tensor& src, platform::Place dst_place) {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place), platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now."); "Tensor::CopyFrom only support CPU now.");
...@@ -119,37 +128,36 @@ class Tensor { ...@@ -119,37 +128,36 @@ class Tensor {
struct Placeholder { struct Placeholder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual void* ptr() const = 0; virtual void* ptr() const = 0;
virtual paddle::platform::Place place() const = 0; virtual platform::Place place() const = 0;
virtual size_t size() const = 0; virtual size_t size() const = 0;
}; };
template <typename T> template <typename T, typename PlaceType>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
private: private:
template <typename PType>
class Deleter { class Deleter {
public: public:
Deleter(platform::Place place) : place_(place) {} Deleter(PType place) : place_(place) {}
void operator()(T* ptr) { void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
paddle::memory::Free(place_, static_cast<void*>(ptr));
}
private: private:
paddle::platform::Place place_; PType place_;
}; };
public: public:
PlaceholderImpl(paddle::platform::Place place, size_t size) PlaceholderImpl(PlaceType place, size_t size)
: ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)), : ptr_(static_cast<T*>(memory::Alloc(place, size)),
Deleter(place)), Deleter<PlaceType>(place)),
place_(place), place_(place),
size_(size) {} size_(size) {}
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t size() const { return size_; } virtual size_t size() const { return size_; }
virtual paddle::platform::Place place() const { return place_; } virtual platform::Place place() const { return place_; }
std::unique_ptr<T, Deleter> ptr_; std::unique_ptr<T, Deleter<PlaceType>> ptr_;
paddle::platform::Place place_; // record the place of ptr_. platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block. size_t size_; // size of the memory block.
}; };
...@@ -166,7 +174,7 @@ class Tensor { ...@@ -166,7 +174,7 @@ class Tensor {
DDim dims_; DDim dims_;
size_t numel_; // cache of `product(dims_)` size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area. size_t offset_; // marks the begin of tensor data area.
}; }; // namespace framework
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -47,7 +47,7 @@ TEST(Tensor, DataAssert) { ...@@ -47,7 +47,7 @@ TEST(Tensor, DataAssert) {
/* following tests are not available at present /* following tests are not available at present
because Memory::Alloc() and Memory::Free() have not been ready. because Memory::Alloc() and Memory::Free() have not been ready.
*/
TEST(Tensor, MutableData) { TEST(Tensor, MutableData) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
...@@ -72,7 +72,7 @@ TEST(Tensor, MutableData) { ...@@ -72,7 +72,7 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
/*
{ {
Tensor src_tensor; Tensor src_tensor;
float* p1 = nullptr; float* p1 = nullptr;
...@@ -94,6 +94,7 @@ TEST(Tensor, MutableData) { ...@@ -94,6 +94,7 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
*/
} }
TEST(Tensor, ShareDataFrom) { TEST(Tensor, ShareDataFrom) {
...@@ -108,9 +109,11 @@ TEST(Tensor, ShareDataFrom) { ...@@ -108,9 +109,11 @@ TEST(Tensor, ShareDataFrom) {
dst_tensor.ShareDataFrom<float>(src_tensor); dst_tensor.ShareDataFrom<float>(src_tensor);
} catch (EnforceNotMet err) { } catch (EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data std::string msg =
first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); "Tenosr holds no memory. Call Tensor::mutable_data first.";
++i) { ASSERT_EQ(what[i], msg[i]); const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
} }
} }
ASSERT_TRUE(caught); ASSERT_TRUE(caught);
...@@ -120,6 +123,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ...@@ -120,6 +123,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
/*
{ {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
...@@ -127,6 +131,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ...@@ -127,6 +131,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
dst_tensor.ShareDataFrom<int>(src_tensor); dst_tensor.ShareDataFrom<int>(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
*/
} }
TEST(Tensor, Slice) { TEST(Tensor, Slice) {
...@@ -155,6 +160,7 @@ TEST(Tensor, Slice) { ...@@ -155,6 +160,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address); EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
} }
/*
{ {
Tensor src_tensor; Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace()); src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
...@@ -176,6 +182,7 @@ TEST(Tensor, Slice) { ...@@ -176,6 +182,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(slice_data_address, slice_mutable_data_address); EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address); EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
} }
*/
} }
TEST(Tensor, CopyFrom) { TEST(Tensor, CopyFrom) {
...@@ -203,4 +210,3 @@ TEST(Tensor, CopyFrom) { ...@@ -203,4 +210,3 @@ TEST(Tensor, CopyFrom) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
} }
\ No newline at end of file
*/
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册