提交 34beec0f 编写于 作者: F fengjiayi

update tensor.h

上级 8594d5c3
...@@ -33,7 +33,7 @@ class Tensor { ...@@ -33,7 +33,7 @@ class Tensor {
template <typename T> template <typename T>
const T* data() const { const T* data() const {
CheckDimsValidity<T>(); CheckDims<T>();
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
} }
...@@ -62,7 +62,7 @@ class Tensor { ...@@ -62,7 +62,7 @@ class Tensor {
template <typename T> template <typename T>
void ShareDataFrom(const Tensor& src) { void ShareDataFrom(const Tensor& src) {
src.CheckDimsValidity<T>(); src.CheckDims<T>();
holder_ = src.holder_; holder_ = src.holder_;
set_dims(src.dims()); set_dims(src.dims());
offset_ = src.offset_; offset_ = src.offset_;
...@@ -73,7 +73,7 @@ class Tensor { ...@@ -73,7 +73,7 @@ class Tensor {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place), platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now."); "Tensor::CopyFrom only support CPU now.");
src.CheckDimsValidity<T>(); src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T); size_t size = src.numel_ * sizeof(T);
set_dims(src.dims()); set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>()); const void* src_ptr = static_cast<const void*>(src.data<T>());
...@@ -83,7 +83,7 @@ class Tensor { ...@@ -83,7 +83,7 @@ class Tensor {
template <typename T> template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const { Tensor Slice(const int& begin_idx, const int& end_idx) const {
CheckDimsValidity<T>(); CheckDims<T>();
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound."); "Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx, PADDLE_ENFORCE(begin_idx < end_idx,
...@@ -109,7 +109,6 @@ class Tensor { ...@@ -109,7 +109,6 @@ class Tensor {
} }
dims_ = dims; dims_ = dims;
numel_ = product(dims_); numel_ = product(dims_);
return;
} }
DDim dims() const { return dims_; } DDim dims() const { return dims_; }
...@@ -155,10 +154,10 @@ class Tensor { ...@@ -155,10 +154,10 @@ class Tensor {
}; };
template <typename T> template <typename T>
inline void CheckDimsValidity() const { inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first."); "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_, PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册