提交 a5eb1d8f 编写于 作者: Q qijun

fix build error

上级 d607f0b7
# ddim lib
cc_library(ddim SRCS ddim.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
......
......@@ -222,9 +222,9 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
}
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) const {
int rank = paddle::framework::arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same")
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
int rank = arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < rank; d++) {
dsizes[d] = dims[d];
......
......@@ -93,7 +93,7 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) const;
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims);
} // namespace framework
} // namespace paddle
......
......@@ -28,13 +28,6 @@ namespace framework {
class Tensor {
public:
template <typename T>
const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data.");
return static_cast<const T*>(holder_->Ptr());
}
template <typename T>
T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
......@@ -60,14 +53,14 @@ class Tensor {
size_t NumElements() const { return product(dims_); }
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(DDim new_dims) {
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes(new_dims);
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
......@@ -92,7 +85,7 @@ class Tensor {
// const versions of all the methods above.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstantTensor Tensor::tensor() const {
typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册