未验证 提交 6ff3596e 编写于 作者: C Chen Weihang 提交者: GitHub

add is dense tensor method (#38424)

上级 6554cc10
......@@ -204,6 +204,14 @@ class PADDLE_API Tensor final {
*/
DataLayout layout() const;
/**
* @brief Determine whether tensor is DenseTensor
*
* @return true
* @return false
*/
bool is_dense_tensor() const;
/* Part 3: Device and Backend methods */
/**
......
......@@ -58,15 +58,6 @@ limitations under the License. */
namespace paddle {
namespace experimental {
namespace detail {
inline bool IsDenseTensor(
const std::shared_ptr<pten::TensorBase> &tensor_impl) {
return tensor_impl->type_info().name() == "DenseTensor";
}
} // namespace detail
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
......@@ -118,7 +109,7 @@ void Tensor::reshape(const std::vector<int64_t> &shape) {
"reason: `reshape` means changing the tensor shape without "
"touching underlying data, this requires the total size of "
"the tensor to remain constant.";
if (detail::IsDenseTensor(impl_)) {
if (is_dense_tensor()) {
std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->set_meta(
pten::DenseTensorMeta(dtype(), framework::make_ddim(shape)));
} else {
......@@ -133,6 +124,10 @@ DataType Tensor::type() const { return impl_->dtype(); }
DataLayout Tensor::layout() const { return impl_->layout(); }
bool Tensor::is_dense_tensor() const {
return pten::DenseTensor::classof(impl_.get());
}
/* Part 3: Device and Backend methods */
PlaceType Tensor::place() const {
......@@ -153,7 +148,7 @@ bool Tensor::is_cuda() const {
template <typename T>
T *Tensor::mutable_data() {
if (detail::IsDenseTensor(impl_)) {
if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)
->mutable_data<T>();
}
......@@ -209,7 +204,7 @@ Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template <typename T>
const T *Tensor::data() const {
if (detail::IsDenseTensor(impl_)) {
if (is_dense_tensor()) {
return std::dynamic_pointer_cast<pten::DenseTensor>(impl_)->data<T>();
}
return nullptr;
......@@ -259,7 +254,7 @@ Tensor::data<paddle::platform::float16>();
// TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(const int64_t begin_idx, const int64_t end_idx) const {
if (detail::IsDenseTensor(impl_)) {
if (is_dense_tensor()) {
return Tensor(std::make_shared<pten::DenseTensor>(
std::move(pten::CompatibleDenseTensorUtils::Slice(
std::dynamic_pointer_cast<pten::DenseTensor>(impl_).get(),
......
......@@ -205,6 +205,11 @@ void TestInitilized() {
}
}
void TestJudgeTensorType() {
experimental::Tensor test_tensor(paddle::PlaceType::kCPU, {1, 1});
CHECK(test_tensor.is_dense_tensor() == true);
}
TEST(PtenTensor, All) {
VLOG(2) << "TestCopy";
GroupTestCopy();
......@@ -220,6 +225,8 @@ TEST(PtenTensor, All) {
GroupTestCast();
VLOG(2) << "TestInitilized";
TestInitilized();
VLOG(2) << "TestJudgeTensorType";
TestJudgeTensorType();
}
} // namespace tests
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册