提交 34beec0f 编写于 作者: F fengjiayi

update tensor.h

上级 8594d5c3
......@@ -33,7 +33,7 @@ class Tensor {
template <typename T>
const T* data() const {
CheckDimsValidity<T>();
CheckDims<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
......@@ -62,7 +62,7 @@ class Tensor {
template <typename T>
void ShareDataFrom(const Tensor& src) {
src.CheckDimsValidity<T>();
src.CheckDims<T>();
holder_ = src.holder_;
set_dims(src.dims());
offset_ = src.offset_;
......@@ -73,7 +73,7 @@ class Tensor {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
src.CheckDimsValidity<T>();
src.CheckDims<T>();
size_t size = src.numel_ * sizeof(T);
set_dims(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>());
......@@ -83,7 +83,7 @@ class Tensor {
template <typename T>
Tensor Slice(const int& begin_idx, const int& end_idx) const {
CheckDimsValidity<T>();
CheckDims<T>();
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
......@@ -109,7 +109,6 @@ class Tensor {
}
dims_ = dims;
numel_ = product(dims_);
return;
}
DDim dims() const { return dims_; }
......@@ -155,10 +154,10 @@ class Tensor {
};
template <typename T>
inline void CheckDimsValidity() const {
inline void CheckDims() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_,
PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册