diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4409c6feae218222b7c0216760cebe4ae8e235cb..2d26a62d0f0bc23d9d4f4ad7233fdceff3585023 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,5 +1,5 @@ # ddim lib -cc_library(ddim SRCS ddim.cc) +cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 3fd3e538e8da41bb6fbf533dad3789688a734e0d..fe8f79abd4fe5d94a8805fa2ddcd8103706dd083 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -222,9 +222,9 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { } template -Eigen::DSizes ToEigenDSizes(DDim dims) const { - int rank = paddle::framework::arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") +Eigen::DSizes ToEigenDSizes(const DDim& dims) { + int rank = arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); Eigen::DSizes dsizes; for (int d = 0; d < rank; d++) { dsizes[d] = dims[d]; diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index a83a367196d7cbfaf3aadb2c9be2eefdde99267b..18395c3636cb710901d13b3660ac81a73270e1cd 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,7 +93,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); template -Eigen::DSizes ToEigenDSizes(DDim dims) const; +Eigen::DSizes ToEigenDSizes(const DDim& dims); } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 0fa74e7ab1331b2fdefe4f85407f382c985c5164..21818937e809b5559d3172409aa492f88834c8f0 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -28,13 +28,6 @@ namespace framework { class Tensor { public: - template - const T* data() const { - PADDLE_ENFORCE(holder_ != nullptr, - "Tensor::data must be called after Tensor::mutable_data."); - return static_cast(holder_->Ptr()); - } - template T* data() const { PADDLE_ENFORCE(holder_ != nullptr, @@ -60,14 +53,14 @@ class Tensor { size_t NumElements() const { return product(dims_); } template - typename TTypes::Tensor Tensor::shaped(DDim new_dims) { + typename TTypes::Tensor shaped(DDim new_dims) { Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); + paddle::framework::ToEigenDSizes(new_dims); return typename TTypes::Tensor(data(), dims); } template - typename TTypes::Tensor Tensor::tensor() { + typename TTypes::Tensor tensor() { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } @@ -92,7 +85,7 @@ class Tensor { // const versions of all the methods above. template - typename TTypes::ConstantTensor Tensor::tensor() const { + typename TTypes::ConstantTensor tensor() const { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); }