From 5860150d96eefc11f55fe9e8408734001ab0483c Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 13 Sep 2017 10:44:53 +0800 Subject: [PATCH] Fix Tensor::Slice with dims[0] == 1. --- paddle/framework/tensor_impl.h | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 642b53efc7..3fcbc5447f 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE_LT(begin_idx, end_idx, "Begin index must be less than end index."); - PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); - size_t base = numel() / dims_[0]; - Tensor dst; - dst.holder_ = holder_; - DDim dst_dims = dims_; - dst_dims[0] = end_idx - begin_idx; - dst.Resize(dst_dims); - dst.offset_ = offset_ + begin_idx * base * sizeof(T); - return dst; + + if (dims_[0] == 1) { + return *this; + } else { + size_t base = numel() / dims_[0]; + Tensor dst; + dst.holder_ = holder_; + DDim dst_dims = dims_; + dst_dims[0] = end_idx - begin_idx; + dst.Resize(dst_dims); + dst.offset_ = offset_ + begin_idx * base * sizeof(T); + return dst; + } } inline Tensor& Tensor::Resize(const DDim& dims) { -- GitLab