未验证 提交 6ff3596e 编写于 作者: C Chen Weihang 提交者: GitHub

add is dense tensor method (#38424)

上级 6554cc10
...@@ -204,6 +204,14 @@ class PADDLE_API Tensor final { ...@@ -204,6 +204,14 @@ class PADDLE_API Tensor final {
*/ */
DataLayout layout() const; DataLayout layout() const;
/**
* @brief Determine whether tensor is DenseTensor
*
* @return true
* @return false
*/
bool is_dense_tensor() const;
/* Part 3: Device and Backend methods */ /* Part 3: Device and Backend methods */
/** /**
......
...@@ -58,15 +58,6 @@ limitations under the License. */ ...@@ -58,15 +58,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
namespace detail {
inline bool IsDenseTensor(
const std::shared_ptr<pten::TensorBase> &tensor_impl) {
return tensor_impl->type_info().name() == "DenseTensor";
}
} // namespace detail
// declare cast api // declare cast api
Tensor cast(const Tensor &x, DataType out_dtype); Tensor cast(const Tensor &x, DataType out_dtype);
...@@ -118,7 +109,7 @@ void Tensor::reshape(const std::vector<int64_t> &shape) { ...@@ -118,7 +109,7 @@ void Tensor::reshape(const std::vector<int64_t> &shape) {
"reason: `reshape` means changing the tensor shape without " "reason: `reshape` means changing the tensor shape without "
"touching underlying data, this requires the total size of " "touching underlying data, this requires the total size of "
"the tensor to remain constant."; "the tensor to remain constant.";
if (detail::IsDenseTensor(impl_)) { if (is_dense_tensor()) {
std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->set_meta( std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->set_meta(
pten::DenseTensorMeta(dtype(), framework::make_ddim(shape))); pten::DenseTensorMeta(dtype(), framework::make_ddim(shape)));
} else { } else {
...@@ -133,6 +124,10 @@ DataType Tensor::type() const { return impl_->dtype(); } ...@@ -133,6 +124,10 @@ DataType Tensor::type() const { return impl_->dtype(); }
DataLayout Tensor::layout() const { return impl_->layout(); } DataLayout Tensor::layout() const { return impl_->layout(); }
bool Tensor::is_dense_tensor() const {
return pten::DenseTensor::classof(impl_.get());
}
/* Part 3: Device and Backend methods */ /* Part 3: Device and Backend methods */
PlaceType Tensor::place() const { PlaceType Tensor::place() const {
...@@ -153,7 +148,7 @@ bool Tensor::is_cuda() const { ...@@ -153,7 +148,7 @@ bool Tensor::is_cuda() const {
template <typename T> template <typename T>
T *Tensor::mutable_data() { T *Tensor::mutable_data() {
if (detail::IsDenseTensor(impl_)) { if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_) return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)
->mutable_data<T>(); ->mutable_data<T>();
} }
...@@ -209,7 +204,7 @@ Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place); ...@@ -209,7 +204,7 @@ Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template <typename T> template <typename T>
const T *Tensor::data() const { const T *Tensor::data() const {
if (detail::IsDenseTensor(impl_)) { if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->data<T>(); return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->data<T>();
} }
return nullptr; return nullptr;
...@@ -259,7 +254,7 @@ Tensor::data<paddle::platform::float16>(); ...@@ -259,7 +254,7 @@ Tensor::data<paddle::platform::float16>();
// TODO(chenweihang): replace slice impl by API // TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const { Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
if (detail::IsDenseTensor(impl_)) { if (is_dense_tensor()) {
return Tensor(std::make_shared<pten::DenseTensor>( return Tensor(std::make_shared<pten::DenseTensor>(
std::move(pten::CompatibleDenseTensorUtils::Slice( std::move(pten::CompatibleDenseTensorUtils::Slice(
std::dynamic_pointer_cast<pten::DenseTensor>(impl_).get(), std::dynamic_pointer_cast<pten::DenseTensor>(impl_).get(),
......
...@@ -205,6 +205,11 @@ void TestInitilized() { ...@@ -205,6 +205,11 @@ void TestInitilized() {
} }
} }
void TestJudgeTensorType() {
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1});
CHECK(test_tensor.is_dense_tensor() == true);
}
TEST(PtenTensor, All) { TEST(PtenTensor, All) {
VLOG(2) << "TestCopy"; VLOG(2) << "TestCopy";
GroupTestCopy(); GroupTestCopy();
...@@ -220,6 +225,8 @@ TEST(PtenTensor, All) { ...@@ -220,6 +225,8 @@ TEST(PtenTensor, All) {
GroupTestCast(); GroupTestCast();
VLOG(2) << "TestInitilized"; VLOG(2) << "TestInitilized";
TestInitilized(); TestInitilized();
VLOG(2) << "TestJudgeTensorType";
TestJudgeTensorType();
} }
} // namespace tests } // namespace tests
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册