未验证 提交 fabc058b 编写于 作者: C Chen Weihang 提交者: GitHub

add copy constructor for densetensor (#38319)

上级 4d1ce184
......@@ -38,6 +38,9 @@ DenseTensor::DenseTensor(intrusive_ptr<Storage> storage,
// Constructs a DenseTensor by taking ownership of both the storage handle and
// the metadata; both arguments are moved, so the caller's objects are left in
// a moved-from state.
DenseTensor::DenseTensor(intrusive_ptr<Storage> storage, DenseTensorMeta&& meta)
: meta_(std::move(meta)), storage_(std::move(storage)) {}
// Copy constructor: copies the metadata and shares the underlying storage via
// copy_intrusive rather than duplicating the tensor's data buffer.
// NOTE(review): assumes copy_intrusive only bumps the intrusive refcount on
// `other.storage_` (i.e. this is a shallow copy) — confirm against the
// intrusive_ptr implementation.
DenseTensor::DenseTensor(const DenseTensor& other)
: meta_(other.meta()), storage_(copy_intrusive(other.storage_)) {}
int64_t DenseTensor::numel() const {
if (meta_.is_scalar) {
return 1;
......
......@@ -62,9 +62,8 @@ class DenseTensor : public TensorBase,
/// \brief Move constructor to support move semantics.
DenseTensor(DenseTensor&& other) = default;
/// \brief We do not recommend deep copy of dense tensor because of its
/// efficiency and complexity across devices. The operation is disabled here.
// NOTE(review): the deleted copy constructor below and the declaration that
// follows it cannot coexist in one translation unit — this is diff residue
// (old line removed, new line added by the commit); the final header keeps
// only the shallow-copy declaration.
DenseTensor(const DenseTensor& other) = delete;
/// \brief DenseTensor shallow copy constructor: shares storage with `other`
/// instead of duplicating device memory (defined in the .cc file).
DenseTensor(const DenseTensor& other);
/// \brief Destroy the tensor object and release exclusive resources.
virtual ~DenseTensor() = default;
......
......@@ -38,10 +38,4 @@ bool DenseTensorMeta::valid() const noexcept {
return valid;
}
// Field-wise equality for DenseTensorMeta: two metas are equal iff every
// descriptive field (scalar flag, dims, dtype, layout, lod, offset) matches.
// The redundant `bool ret = true; return ret && ...` scaffolding was removed;
// `true && x` is just `x`.
bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) {
  return (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) &&
         (lhs.dtype == rhs.dtype) && (lhs.layout == rhs.layout) &&
         (lhs.lod == rhs.lod) && (lhs.offset == rhs.offset);
}
} // namespace pten
......@@ -60,4 +60,11 @@ struct DenseTensorMeta {
size_t offset{0};
};
/// \brief Field-wise equality for DenseTensorMeta: equal iff every descriptive
/// field (scalar flag, dims, dtype, layout, lod, offset) matches.
/// The dead `bool ret = true;` accumulator was removed — `true && x` is `x`.
inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) {
  return (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) &&
         (lhs.dtype == rhs.dtype) && (lhs.layout == rhs.layout) &&
         (lhs.lod == rhs.lod) && (lhs.offset == rhs.offset);
}
} // namespace pten
......@@ -122,5 +122,20 @@ TEST(dense_tensor, resize) {
CHECK_EQ(storage->size(), 6u);
}
// Verifies that the DenseTensor copy constructor performs a shallow copy:
// the copy must carry equal metadata and share the same underlying storage.
TEST(dense_tensor, shallow_copy) {
// Minimal 1x2 INT8/NHWC meta with an empty lod.
const DDim dims({1, 2});
const DataType dtype{DataType::INT8};
const DataLayout layout{DataLayout::NHWC};
const std::vector<std::vector<size_t>> lod{};
DenseTensorMeta meta(dtype, dims, layout, lod);
auto alloc = std::make_shared<FancyAllocator>();
DenseTensor tensor_0(alloc, meta);
// Invoke the copy constructor under test.
DenseTensor tensor_1(tensor_0);
// Metadata must compare equal field-by-field.
CHECK(tensor_0.meta() == tensor_1.meta());
// Both tensors should yield the same storage pointer, proving the copy
// shares rather than duplicates the buffer.
// NOTE(review): release() presumably relinquishes each tensor's hold on its
// storage — after this line both tensors are storage-less, which is fine
// only because the test ends here; confirm release() semantics.
CHECK(tensor_0.release() == tensor_1.release());
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册