未验证 提交 c5ccff73 编写于 作者: 石晓伟 提交者: GitHub

supports the slice of upper tensor, test=develop (#37215)

上级 f49c2c23
......@@ -45,6 +45,39 @@ class CompatibleDenseTensorUtils {
static_cast<paddle::experimental::SharedStorage*>(tensor->storage_.get())
->Reset();
}
static DenseTensor Slice(DenseTensor* tensor,
int64_t begin_idx,
int64_t end_idx) {
tensor->check_memory_size();
PADDLE_ENFORCE_GE(begin_idx,
0,
paddle::platform::errors::OutOfRange(
"The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(end_idx,
tensor->dims()[0],
paddle::platform::errors::OutOfRange(
"The end row index is out of bound."));
PADDLE_ENFORCE_LT(
begin_idx,
end_idx,
paddle::platform::errors::InvalidArgument(
"The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.",
begin_idx,
end_idx));
DenseTensor ret =
DenseTensor(copy_intrusive(tensor->storage_), tensor->meta_);
if (tensor->dims()[0] != 1) {
ret.meta_.dims[0] = end_idx - begin_idx;
ret.meta_.offset = tensor->meta_.offset +
begin_idx * (tensor->numel() / tensor->dims()[0]) *
paddle::experimental::SizeOf(tensor->data_type());
}
return ret;
}
};
} // namespace pten
......@@ -174,12 +174,6 @@ class DenseTensor : public TensorBase,
/// \return The const data pointer value of raw type.
const void* data() const;
/// \brief Get the shallow clone of current tensor.
/// \return The shallow clone of current tensor.
DenseTensor shallow_clone() const {
return DenseTensor(copy_intrusive(storage_), meta_);
}
private:
friend class CompatibleDenseTensorUtils;
......
......@@ -57,6 +57,7 @@ struct DenseTensorMeta {
const DataType type{DataType::UNDEFINED};
const DataLayout layout{DataLayout::NCHW};
LoD lod;
size_t offset{0};
};
inline DenseTensorMeta::DenseTensorMeta(DataType type, const DDim& dims)
......@@ -86,7 +87,7 @@ inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) {
bool ret = true;
return ret && (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) &&
(lhs.type == rhs.type) && (lhs.layout == rhs.layout) &&
(lhs.lod == rhs.lod);
(lhs.lod == rhs.lod) && (lhs.offset == rhs.offset);
}
} // namespace pten
......@@ -125,20 +125,5 @@ TEST(dense_tensor, resize) {
CHECK_EQ(storage->size(), 6u);
}
TEST(dense_tensor, shallow_clone) {
const DDim dims({1, 2});
const DataType dtype{DataType::INT8};
const DataLayout layout{DataLayout::NHWC};
const std::vector<std::vector<size_t>> lod{};
DenseTensorMeta meta(dtype, dims, layout, lod);
auto alloc = std::make_shared<FancyAllocator>();
DenseTensor tensor_0(alloc, meta);
auto tensor_1 = tensor_0.shallow_clone();
CHECK(tensor_0.meta() == tensor_1.meta());
CHECK(tensor_0.release() == tensor_1.release());
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册