From 696ba1d2e1f3fdac763c4dd29b5353b512f9b7fa Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 3 Jul 2017 16:01:50 +0800 Subject: [PATCH] init tensor_test.cc --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/tensor.h | 5 +-- paddle/framework/tensor_test.cc | 71 +++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 paddle/framework/tensor_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6aa6b9bc2db..41bf3837aa2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -2,6 +2,7 @@ cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) +cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) cc_test(variable_test SRCS variable_test.cc) cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 067f2a85264..8d658d50972 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -19,13 +19,12 @@ namespace framework { class Tensor { using paddle::platform::Place; - using paddle::platform::get_place; public: template const T* data() const { - PADDLE_ASSERT(holder_ != nullptr, - "Tensor::data must be called after Tensor::mutable_data"); + PADDLE_ENFORCE(holder_ != nullptr, + "Tensor::data must be called after Tensor::mutable_data"); return static_cast(holder->Ptr()); } diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc new file mode 100644 index 00000000000..fa44b24b645 --- /dev/null +++ b/paddle/framework/tensor_test.cc @@ -0,0 +1,71 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/framework/tensor.h" +#include + +TEST(Tensor, Data) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor cpu_tensor; +} + +/* mutable_data() is not tested at present + because Memory::Alloc() and Memory::Free() have not been ready. + +TEST(Tensor, MutableData) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor cpu_tensor; + float* p1 = nullptr; + float* p2 = nullptr; + // initialization + p1 = cpu_tensor.mutable_data(make_ddim({1, 2, 3}), CPUPlace()); + EXPECT_NE(p1, nullptr); + // set cpu_tensor a new dim with large size + // momery is supposed to be re-allocated + p2 = cpu_tensor.mutable_data(make_ddim({3, 4})); + EXPECT_NE(p2, nullptr); + EXPECT_NE(p1, p2); + // set cpu_tensor a new dim with same size + // momery block is supposed to be unchanged + p1 = cpu_tensor.mutable_data(make_ddim({2, 2, 3})); + EXPECT_EQ(p1, p2); + // set cpu_tensor a new dim with smaller size + // momery block is supposed to be unchanged + p2 = cpu_tensor.mutable_data(make_ddim({2, 2})); + EXPECT_EQ(p1, p2); + + Tensor gpu_tensor; + float* p1 = nullptr; + float* p2 = nullptr; + // initialization + p1 = gpu_tensor.mutable_data(make_ddim({1, 2, 3}), GPUPlace()); + EXPECT_NE(p1, nullptr); + // set gpu_tensor a new dim with large size + // momery is supposed to be re-allocated + p2 = gpu_tensor.mutable_data(make_ddim({3, 4})); + EXPECT_NE(p2, nullptr); + EXPECT_NE(p1, p2); + // set gpu_tensor a new dim with same size + // momery block is supposed to be unchanged + p1 = gpu_tensor.mutable_data(make_ddim({2, 2, 3})); + EXPECT_EQ(p1, p2); + // set gpu_tensor a new dim with smaller size + // momery block is supposed to be unchanged + p2 = gpu_tensor.mutable_data(make_ddim({2, 2})); + EXPECT_EQ(p1, p2); +} +*/ \ No newline at end of file -- GitLab