未验证 提交 f5ae67e8 编写于 作者: R Ruibiao Chen 提交者: GitHub

Isolate DenseTensor::set_type and DenseTensor::set_layout from header file (#52591)

* Isolate DenseTensor::set_type from header file

* Fix selected_rows
上级 9a0de116
......@@ -23,4 +23,10 @@ SelectedRows::SelectedRows(const std::vector<int64_t>& rows,
SelectedRows::SelectedRows()
: impl_(std::make_shared<phi::SelectedRowsImpl>()) {}
void SelectedRows::set_type(const DataType dtype) { impl_->set_type(dtype); }
void SelectedRows::set_layout(const DataLayout layout) {
impl_->set_layout(layout);
}
} // namespace phi
......@@ -139,13 +139,17 @@ class SelectedRows : public TensorBase,
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return impl_->dtype(); }
void set_type(const DataType dtype) { impl_->set_type(dtype); }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return impl_->layout(); }
void set_layout(const DataLayout layout) { impl_->set_layout(layout); }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
......
......@@ -208,4 +208,13 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids,
}
}
}
void SelectedRowsImpl::set_type(const DataType dtype) {
value_->set_type(dtype);
}
void SelectedRowsImpl::set_layout(const DataLayout layout) {
value_->set_layout(layout);
}
} // namespace phi
......@@ -159,13 +159,17 @@ class SelectedRowsImpl {
/// \return The data type of the tensor.
DataType dtype() const noexcept { return value_->dtype(); }
void set_type(const DataType dtype) { value_->set_type(dtype); }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept { return value_->layout(); }
void set_layout(const DataLayout layout) { value_->set_layout(layout); }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
......
......@@ -84,6 +84,12 @@ int64_t SparseCooTensor::nnz() const {
}
}
void SparseCooTensor::set_type(const DataType dtype) { meta_.dtype = dtype; }
void SparseCooTensor::set_layout(const DataLayout layout) {
meta_.layout = layout;
}
void SparseCooTensor::Resize(const DDim& dense_dims,
const int64_t sparse_dim,
const int64_t non_zero_num) {
......
......@@ -104,13 +104,18 @@ class SparseCooTensor : public TensorBase,
/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
......
......@@ -88,6 +88,12 @@ void* SparseCsrTensor::AllocateFrom(Allocator* allocator,
allocator, dtype, requested_size, fake_alloc);
}
void SparseCsrTensor::set_type(const DataType dtype) { meta_.dtype = dtype; }
void SparseCsrTensor::set_layout(const DataLayout layout) {
meta_.layout = layout;
}
void SparseCsrTensor::Resize(const DDim& dense_dims,
const int64_t non_zero_num) {
PADDLE_ENFORCE(this->initialized(),
......
......@@ -110,13 +110,17 @@ class SparseCsrTensor : public TensorBase,
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; }
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
......
......@@ -65,11 +65,15 @@ class TensorArray : public TensorBase,
DataType dtype() const override;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_type(const DataType dtype);
#endif
DataLayout layout() const override;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
void set_layout(const DataLayout layout);
#endif
/// \brief This overrided function is not used in TensorArray.
bool valid() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册