提交 68adb954 编写于 作者: F fengjiayi

enbale tensor memory test

上级 9b8451cc
......@@ -2,7 +2,7 @@
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim paddle_memory)
cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
......
......@@ -29,8 +29,6 @@ class Tensor {
public:
Tensor() : numel_(0), offset_(0) {}
Tensor& operator=(const Tensor& src) = delete;
template <typename T>
const T* data() const {
CheckDims<T>();
......@@ -39,13 +37,13 @@ class Tensor {
}
template <typename T>
T* mutable_data(DDim dims, paddle::platform::Place place) {
T* mutable_data(DDim dims, platform::Place place) {
set_dims(dims);
return mutable_data<T>(place);
}
template <typename T>
T* mutable_data(paddle::platform::Place place) {
T* mutable_data(platform::Place place) {
PADDLE_ENFORCE(numel_ > 0,
"Tensor::numel_ must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
......@@ -53,7 +51,18 @@ class Tensor {
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, numel_ * sizeof(T)));
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
}
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
......@@ -69,7 +78,7 @@ class Tensor {
}
template <typename T>
void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) {
void CopyFrom(const Tensor& src, platform::Place dst_place) {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
......@@ -119,37 +128,36 @@ class Tensor {
struct Placeholder {
virtual ~Placeholder() {}
virtual void* ptr() const = 0;
virtual paddle::platform::Place place() const = 0;
virtual platform::Place place() const = 0;
virtual size_t size() const = 0;
};
template <typename T>
template <typename T, typename PlaceType>
struct PlaceholderImpl : public Placeholder {
private:
template <typename PType>
class Deleter {
public:
Deleter(platform::Place place) : place_(place) {}
void operator()(T* ptr) {
paddle::memory::Free(place_, static_cast<void*>(ptr));
}
Deleter(PType place) : place_(place) {}
void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); }
private:
paddle::platform::Place place_;
PType place_;
};
public:
PlaceholderImpl(paddle::platform::Place place, size_t size)
: ptr_(static_cast<T*>(paddle::memory::Alloc(place, size)),
Deleter(place)),
PlaceholderImpl(PlaceType place, size_t size)
: ptr_(static_cast<T*>(memory::Alloc(place, size)),
Deleter<PlaceType>(place)),
place_(place),
size_(size) {}
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t size() const { return size_; }
virtual paddle::platform::Place place() const { return place_; }
virtual platform::Place place() const { return place_; }
std::unique_ptr<T, Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_.
std::unique_ptr<T, Deleter<PlaceType>> ptr_;
platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};
......@@ -166,7 +174,7 @@ class Tensor {
DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area.
};
}; // namespace framework
} // namespace framework
} // namespace paddle
......@@ -47,7 +47,7 @@ TEST(Tensor, DataAssert) {
/* following tests are not available at present
because Memory::Alloc() and Memory::Free() have not been ready.
*/
TEST(Tensor, MutableData) {
using namespace paddle::framework;
using namespace paddle::platform;
......@@ -72,7 +72,7 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2);
}
/*
{
Tensor src_tensor;
float* p1 = nullptr;
......@@ -94,6 +94,7 @@ TEST(Tensor, MutableData) {
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
EXPECT_EQ(p1, p2);
}
*/
}
TEST(Tensor, ShareDataFrom) {
......@@ -108,9 +109,11 @@ TEST(Tensor, ShareDataFrom) {
dst_tensor.ShareDataFrom<float>(src_tensor);
} catch (EnforceNotMet err) {
caught = true;
std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data
first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
++i) { ASSERT_EQ(what[i], msg[i]);
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(caught);
......@@ -120,6 +123,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
/*
{
Tensor src_tensor;
Tensor dst_tensor;
......@@ -127,6 +131,7 @@ first."; const char* what = err.what(); for (size_t i = 0; i < msg.length();
dst_tensor.ShareDataFrom<int>(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
*/
}
TEST(Tensor, Slice) {
......@@ -155,6 +160,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
}
/*
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
......@@ -176,6 +182,7 @@ TEST(Tensor, Slice) {
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
}
*/
}
TEST(Tensor, CopyFrom) {
......@@ -203,4 +210,3 @@ TEST(Tensor, CopyFrom) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
}
\ No newline at end of file
*/
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册