未验证 提交 f5ae67e8 编写于 作者: R Ruibiao Chen 提交者: GitHub

Isolate DenseTensor::set_type and DenseTensor::set_layout from header file (#52591)

* Isolate DenseTensor::set_type from header file

* Fix selected_rows
上级 9a0de116
...@@ -23,4 +23,10 @@ SelectedRows::SelectedRows(const std::vector<int64_t>& rows, ...@@ -23,4 +23,10 @@ SelectedRows::SelectedRows(const std::vector<int64_t>& rows,
SelectedRows::SelectedRows() SelectedRows::SelectedRows()
: impl_(std::make_shared<phi::SelectedRowsImpl>()) {} : impl_(std::make_shared<phi::SelectedRowsImpl>()) {}
void SelectedRows::set_type(const DataType dtype) { impl_->set_type(dtype); }
void SelectedRows::set_layout(const DataLayout layout) {
impl_->set_layout(layout);
}
} // namespace phi } // namespace phi
...@@ -139,13 +139,17 @@ class SelectedRows : public TensorBase, ...@@ -139,13 +139,17 @@ class SelectedRows : public TensorBase,
/// \return The data type of the tensor. /// \return The data type of the tensor.
DataType dtype() const noexcept override { return impl_->dtype(); } DataType dtype() const noexcept override { return impl_->dtype(); }
void set_type(const DataType dtype) { impl_->set_type(dtype); } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor. /// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor. /// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return impl_->layout(); } DataLayout layout() const noexcept override { return impl_->layout(); }
void set_layout(const DataLayout layout) { impl_->set_layout(layout); } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor. /// \brief Returns the data place of the tensor.
/// \return The data place of the tensor. /// \return The data place of the tensor.
......
...@@ -208,4 +208,13 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids, ...@@ -208,4 +208,13 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids,
} }
} }
} }
void SelectedRowsImpl::set_type(const DataType dtype) {
value_->set_type(dtype);
}
void SelectedRowsImpl::set_layout(const DataLayout layout) {
value_->set_layout(layout);
}
} // namespace phi } // namespace phi
...@@ -159,13 +159,17 @@ class SelectedRowsImpl { ...@@ -159,13 +159,17 @@ class SelectedRowsImpl {
/// \return The data type of the tensor. /// \return The data type of the tensor.
DataType dtype() const noexcept { return value_->dtype(); } DataType dtype() const noexcept { return value_->dtype(); }
void set_type(const DataType dtype) { value_->set_type(dtype); } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor. /// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor. /// \return The data layout of the tensor.
DataLayout layout() const noexcept { return value_->layout(); } DataLayout layout() const noexcept { return value_->layout(); }
void set_layout(const DataLayout layout) { value_->set_layout(layout); } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor. /// \brief Returns the data place of the tensor.
/// \return The data place of the tensor. /// \return The data place of the tensor.
......
...@@ -84,6 +84,12 @@ int64_t SparseCooTensor::nnz() const { ...@@ -84,6 +84,12 @@ int64_t SparseCooTensor::nnz() const {
} }
} }
void SparseCooTensor::set_type(const DataType dtype) { meta_.dtype = dtype; }
void SparseCooTensor::set_layout(const DataLayout layout) {
meta_.layout = layout;
}
void SparseCooTensor::Resize(const DDim& dense_dims, void SparseCooTensor::Resize(const DDim& dense_dims,
const int64_t sparse_dim, const int64_t sparse_dim,
const int64_t non_zero_num) { const int64_t non_zero_num) {
......
...@@ -104,13 +104,18 @@ class SparseCooTensor : public TensorBase, ...@@ -104,13 +104,18 @@ class SparseCooTensor : public TensorBase,
/// \brief Returns the data type of the tensor. /// \brief Returns the data type of the tensor.
/// \return The data type of the tensor. /// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; } DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor. /// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor. /// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; } DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor. /// \brief Returns the data place of the tensor.
/// \return The data place of the tensor. /// \return The data place of the tensor.
......
...@@ -88,6 +88,12 @@ void* SparseCsrTensor::AllocateFrom(Allocator* allocator, ...@@ -88,6 +88,12 @@ void* SparseCsrTensor::AllocateFrom(Allocator* allocator,
allocator, dtype, requested_size, fake_alloc); allocator, dtype, requested_size, fake_alloc);
} }
void SparseCsrTensor::set_type(const DataType dtype) { meta_.dtype = dtype; }
void SparseCsrTensor::set_layout(const DataLayout layout) {
meta_.layout = layout;
}
void SparseCsrTensor::Resize(const DDim& dense_dims, void SparseCsrTensor::Resize(const DDim& dense_dims,
const int64_t non_zero_num) { const int64_t non_zero_num) {
PADDLE_ENFORCE(this->initialized(), PADDLE_ENFORCE(this->initialized(),
......
...@@ -110,13 +110,17 @@ class SparseCsrTensor : public TensorBase, ...@@ -110,13 +110,17 @@ class SparseCsrTensor : public TensorBase,
/// \return The data type of the tensor. /// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; } DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor. /// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor. /// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; } DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; } #ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor. /// \brief Returns the data place of the tensor.
/// \return The data place of the tensor. /// \return The data place of the tensor.
......
...@@ -65,11 +65,15 @@ class TensorArray : public TensorBase, ...@@ -65,11 +65,15 @@ class TensorArray : public TensorBase,
DataType dtype() const override; DataType dtype() const override;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype); void set_type(const DataType dtype);
#endif
DataLayout layout() const override; DataLayout layout() const override;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout); void set_layout(const DataLayout layout);
#endif
/// \brief This overrided function is not used in TensorArray. /// \brief This overrided function is not used in TensorArray.
bool valid() const override; bool valid() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册