提交 5860150d 编写于 作者: H hedaoyuan

Fix Tensor::Slice with dims[0] == 1.

上级 db33ff12
...@@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { ...@@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx, PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index."); "Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
size_t base = numel() / dims_[0]; if (dims_[0] == 1) {
Tensor dst; return *this;
dst.holder_ = holder_; } else {
DDim dst_dims = dims_; size_t base = numel() / dims_[0];
dst_dims[0] = end_idx - begin_idx; Tensor dst;
dst.Resize(dst_dims); dst.holder_ = holder_;
dst.offset_ = offset_ + begin_idx * base * sizeof(T); DDim dst_dims = dims_;
return dst; dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
} }
inline Tensor& Tensor::Resize(const DDim& dims) { inline Tensor& Tensor::Resize(const DDim& dims) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册