未验证 提交 d89f246c 编写于 作者: Z zhangkaihuo 提交者: GitHub

implement AllocateFrom (#39280)

上级 35f949b5
......@@ -47,6 +47,12 @@ SparseCooTensor SparseCooTensor::operator=(const SparseCooTensor& other) {
return *this;
}
void* SparseCooTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size) {
return non_zero_elements_.AllocateFrom(allocator, dtype, requested_size);
}
int64_t SparseCooTensor::nnz() const {
const auto indices_dims = non_zero_indices_.dims();
if (indices_dims.size() == 0) {
......
......@@ -141,6 +141,11 @@ class SparseCooTensor : public TensorBase,
/// return a mutable pointer of non_zero_elements.
DenseTensor* mutable_non_zero_elements() { return &non_zero_elements_; }
/// \brief This function is not recommended
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
private:
// save the indices of non zero elements in original dense tensor
DenseTensor non_zero_indices_;
......
......@@ -67,6 +67,12 @@ SparseCsrTensor& SparseCsrTensor::operator=(const SparseCsrTensor& other) {
return *this;
}
void* SparseCsrTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size) {
return non_zero_elements_.AllocateFrom(allocator, dtype, requested_size);
}
void SparseCsrTensor::Resize(const DDim& dense_dims,
const int64_t non_zero_num) {
PADDLE_ENFORCE(this->initialized(),
......
......@@ -59,6 +59,11 @@ class SparseCsrTensor : public TensorBase,
/// \brief Destroy the tensor object and release exclusive resources.
virtual ~SparseCsrTensor() = default;
/// \brief This function is not recommended
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
......
......@@ -24,11 +24,4 @@ void Copy(const Context& dev_ctx,
const DenseTensor& src,
bool blocking,
DenseTensor* dst);
template <typename Context>
void CopySparse(const Context& dev_ctx,
const SparseCsrTensor& src,
bool blocking,
SparseCsrTensor* dst);
} // namespace pten
......@@ -215,25 +215,7 @@ void Copy(const Context& dev_ctx,
}
}
template <typename Context>
void CopySparse(const Context& dev_ctx,
const SparseCsrTensor& src,
bool blocking,
SparseCsrTensor* dst) {
Copy(dev_ctx, src.non_zero_crows(), blocking, dst->mutable_non_zero_crows());
Copy(dev_ctx, src.non_zero_cols(), blocking, dst->mutable_non_zero_cols());
Copy(dev_ctx,
src.non_zero_elements(),
blocking,
dst->mutable_non_zero_elements());
}
} // namespace pten
PT_REGISTER_GENERAL_KERNEL(
copy, GPU, ALL_LAYOUT, pten::Copy<pten::GPUContext>, ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(copy_sparse,
GPU,
ALL_LAYOUT,
pten::CopySparse<pten::GPUContext>,
ALL_DTYPE) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册