From f5ae67e8799bea7ef12c9086953da37b1edf050e Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Fri, 7 Apr 2023 15:26:28 +0800 Subject: [PATCH] Isolate DenseTensor::set_type and DenseTensor::set_layout from header file (#52591) * Isolate DenseTensor::set_type from header file * Fix selected_rows --- paddle/phi/core/selected_rows.cc | 6 ++++++ paddle/phi/core/selected_rows.h | 8 ++++++-- paddle/phi/core/selected_rows_impl.cc | 9 +++++++++ paddle/phi/core/selected_rows_impl.h | 8 ++++++-- paddle/phi/core/sparse_coo_tensor.cc | 6 ++++++ paddle/phi/core/sparse_coo_tensor.h | 9 +++++++-- paddle/phi/core/sparse_csr_tensor.cc | 6 ++++++ paddle/phi/core/sparse_csr_tensor.h | 8 ++++++-- paddle/phi/core/tensor_array.h | 4 ++++ 9 files changed, 56 insertions(+), 8 deletions(-) diff --git a/paddle/phi/core/selected_rows.cc b/paddle/phi/core/selected_rows.cc index dcf9c418215..ec2d0d61fae 100644 --- a/paddle/phi/core/selected_rows.cc +++ b/paddle/phi/core/selected_rows.cc @@ -23,4 +23,10 @@ SelectedRows::SelectedRows(const std::vector& rows, SelectedRows::SelectedRows() : impl_(std::make_shared()) {} +void SelectedRows::set_type(const DataType dtype) { impl_->set_type(dtype); } + +void SelectedRows::set_layout(const DataLayout layout) { + impl_->set_layout(layout); +} + } // namespace phi diff --git a/paddle/phi/core/selected_rows.h b/paddle/phi/core/selected_rows.h index aa528969fbf..dcd63fa83c0 100644 --- a/paddle/phi/core/selected_rows.h +++ b/paddle/phi/core/selected_rows.h @@ -139,13 +139,17 @@ class SelectedRows : public TensorBase, /// \return The data type of the tensor. DataType dtype() const noexcept override { return impl_->dtype(); } - void set_type(const DataType dtype) { impl_->set_type(dtype); } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_type(const DataType dtype); +#endif /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. DataLayout layout() const noexcept override { return impl_->layout(); } - void set_layout(const DataLayout layout) { impl_->set_layout(layout); } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_layout(const DataLayout layout); +#endif /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. diff --git a/paddle/phi/core/selected_rows_impl.cc b/paddle/phi/core/selected_rows_impl.cc index d5143be2e84..f0fbefe2fc5 100644 --- a/paddle/phi/core/selected_rows_impl.cc +++ b/paddle/phi/core/selected_rows_impl.cc @@ -208,4 +208,13 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids, } } } + +void SelectedRowsImpl::set_type(const DataType dtype) { + value_->set_type(dtype); +} + +void SelectedRowsImpl::set_layout(const DataLayout layout) { + value_->set_layout(layout); +} + } // namespace phi diff --git a/paddle/phi/core/selected_rows_impl.h b/paddle/phi/core/selected_rows_impl.h index a1864ad3aa6..a29f66b9942 100644 --- a/paddle/phi/core/selected_rows_impl.h +++ b/paddle/phi/core/selected_rows_impl.h @@ -159,13 +159,17 @@ class SelectedRowsImpl { /// \return The data type of the tensor. DataType dtype() const noexcept { return value_->dtype(); } - void set_type(const DataType dtype) { value_->set_type(dtype); } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_type(const DataType dtype); +#endif /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. DataLayout layout() const noexcept { return value_->layout(); } - void set_layout(const DataLayout layout) { value_->set_layout(layout); } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_layout(const DataLayout layout); +#endif /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index 6d3296e2852..b7b0d06de8a 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -84,6 +84,12 @@ int64_t SparseCooTensor::nnz() const { } } +void SparseCooTensor::set_type(const DataType dtype) { meta_.dtype = dtype; } + +void SparseCooTensor::set_layout(const DataLayout layout) { + meta_.layout = layout; +} + void SparseCooTensor::Resize(const DDim& dense_dims, const int64_t sparse_dim, const int64_t non_zero_num) { diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index 542f4e86277..0e9273f321f 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -104,13 +104,18 @@ class SparseCooTensor : public TensorBase, /// \brief Returns the data type of the tensor. /// \return The data type of the tensor. DataType dtype() const noexcept override { return meta_.dtype; } - void set_type(const DataType dtype) { meta_.dtype = dtype; } + +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_type(const DataType dtype); +#endif /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. DataLayout layout() const noexcept override { return meta_.layout; } - void set_layout(const DataLayout layout) { meta_.layout = layout; } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_layout(const DataLayout layout); +#endif /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. diff --git a/paddle/phi/core/sparse_csr_tensor.cc b/paddle/phi/core/sparse_csr_tensor.cc index 0b4662760c0..32680106a96 100644 --- a/paddle/phi/core/sparse_csr_tensor.cc +++ b/paddle/phi/core/sparse_csr_tensor.cc @@ -88,6 +88,12 @@ void* SparseCsrTensor::AllocateFrom(Allocator* allocator, allocator, dtype, requested_size, fake_alloc); } +void SparseCsrTensor::set_type(const DataType dtype) { meta_.dtype = dtype; } + +void SparseCsrTensor::set_layout(const DataLayout layout) { + meta_.layout = layout; +} + void SparseCsrTensor::Resize(const DDim& dense_dims, const int64_t non_zero_num) { PADDLE_ENFORCE(this->initialized(), diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index ec9dd7ab790..8692c8d7a20 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -110,13 +110,17 @@ class SparseCsrTensor : public TensorBase, /// \return The data type of the tensor. DataType dtype() const noexcept override { return meta_.dtype; } - void set_type(const DataType dtype) { meta_.dtype = dtype; } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_type(const DataType dtype); +#endif /// \brief Returns the data layout of the tensor. /// \return The data layout of the tensor. DataLayout layout() const noexcept override { return meta_.layout; } - void set_layout(const DataLayout layout) { meta_.layout = layout; } +#ifndef PADDLE_WITH_CUSTOM_KERNEL + void set_layout(const DataLayout layout); +#endif /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. diff --git a/paddle/phi/core/tensor_array.h b/paddle/phi/core/tensor_array.h index 4fd8fe1df5e..f9ec8c7148f 100644 --- a/paddle/phi/core/tensor_array.h +++ b/paddle/phi/core/tensor_array.h @@ -65,11 +65,15 @@ class TensorArray : public TensorBase, DataType dtype() const override; +#ifndef PADDLE_WITH_CUSTOM_KERNEL void set_type(const DataType dtype); +#endif DataLayout layout() const override; +#ifndef PADDLE_WITH_CUSTOM_KERNEL void set_layout(const DataLayout layout); +#endif /// \brief This overrided function is not used in TensorArray. bool valid() const override; -- GitLab