提交 5860150d 编写于 作者: H hedaoyuan

Fix Tensor::Slice with dims[0] == 1.

上级 db33ff12
...@@ -130,7 +130,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { ...@@ -130,7 +130,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx, PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index."); "Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
if (dims_[0] == 1) {
return *this;
} else {
size_t base = numel() / dims_[0]; size_t base = numel() / dims_[0];
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
...@@ -139,6 +142,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { ...@@ -139,6 +142,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
dst.Resize(dst_dims); dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T); dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst; return dst;
}
} }
inline Tensor& Tensor::Resize(const DDim& dims) { inline Tensor& Tensor::Resize(const DDim& dims) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册