提交 a5eb1d8f 编写于 作者: Q qijun

fix build error

上级 d607f0b7
# ddim lib # ddim lib
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
......
...@@ -222,9 +222,9 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { ...@@ -222,9 +222,9 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
} }
template <int NDIMS> template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) const { Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
int rank = paddle::framework::arity(dims); int rank = arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < rank; d++) { for (int d = 0; d < rank; d++) {
dsizes[d] = dims[d]; dsizes[d] = dims[d];
......
...@@ -93,7 +93,7 @@ int arity(const DDim& ddim); ...@@ -93,7 +93,7 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&); std::ostream& operator<<(std::ostream&, const DDim&);
template <int NDIMS> template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) const; Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -28,13 +28,6 @@ namespace framework { ...@@ -28,13 +28,6 @@ namespace framework {
class Tensor { class Tensor {
public: public:
template <typename T>
const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data.");
return static_cast<const T*>(holder_->Ptr());
}
template <typename T> template <typename T>
T* data() const { T* data() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
...@@ -60,14 +53,14 @@ class Tensor { ...@@ -60,14 +53,14 @@ class Tensor {
size_t NumElements() const { return product(dims_); } size_t NumElements() const { return product(dims_); }
template <typename T, size_t NDIMS> template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(DDim new_dims) { typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims = Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes(new_dims); paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims); return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
} }
template <typename T, size_t NDIMS> template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() { typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor( return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_)); data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
} }
...@@ -92,7 +85,7 @@ class Tensor { ...@@ -92,7 +85,7 @@ class Tensor {
// const versions of all the methods above. // const versions of all the methods above.
template <typename T, size_t NDIMS> template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstantTensor Tensor::tensor() const { typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
return typename TTypes<T, NDIMS>::Tensor( return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_)); data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册