From 3e80253ae1faeb37eb3eb458d93bd092c8c5fbcd Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 26 Jan 2022 15:39:49 +0800 Subject: [PATCH] [Move selected_rows PR #4] SelectedRows inherits from TensorBase. (#39162) * Added selected_rows and rw_lock to pten * Renamed the unit test target to fix CI * Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid * Remove rw_lock.h,rw_lock_test.cc in fluid * Use pten::RWLock and pten::AutoRDLock, fix CI * Use pten::SelectedRows * Use pten::SelectedRows * Fix to pass NPU CI * Selected_Rows inherits from TensorBase * Use pten::SelectedRows, to pass NPU CI * To fix NPU CI * To fix NPU CI again * Use paddle/pten/core/enforce and polish code --- paddle/pten/core/CMakeLists.txt | 2 +- paddle/pten/core/selected_rows.h | 56 +++++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index 757eed25974..0c5437ff6d0 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) -cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim) +cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) diff --git a/paddle/pten/core/selected_rows.h b/paddle/pten/core/selected_rows.h index f5be0a906db..e12f59d02f2 100644 --- a/paddle/pten/core/selected_rows.h +++ b/paddle/pten/core/selected_rows.h @@ -24,15 +24,16 @@ limitations under the License. */ #include "paddle/pten/common/place.h" #include "paddle/pten/core/ddim.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/enforce.h" #include "paddle/pten/core/utils/rw_lock.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/enforce.h" namespace pten { -class SelectedRows { +class SelectedRows : public TensorBase, + public TypeInfoTraits { /* * @brief We can use the SelectedRows structure to reproduce a sparse table. * A sparse table is a key-value structure that the key is an `int64_t`, @@ -51,21 +52,19 @@ class SelectedRows { public: SelectedRows(const std::vector& rows, const int64_t& height) : rows_(rows), height_(height) { - value_.reset(new pten::DenseTensor()); + value_.reset(new DenseTensor()); rwlock_.reset(new RWLock); } SelectedRows() { height_ = 0; - value_.reset(new pten::DenseTensor()); + value_.reset(new DenseTensor()); rwlock_.reset(new RWLock); } - const pten::Place& place() const { return value_->place(); } + const DenseTensor& value() const { return *value_; } - const pten::DenseTensor& value() const { return *value_; } - - pten::DenseTensor* mutable_value() { return value_.get(); } + DenseTensor* mutable_value() { return value_.get(); } int64_t height() const { return height_; } @@ -109,8 +108,8 @@ class SelectedRows { * @return a list of pair which contains the non-exists key and the index in * the value */ - void Get(const pten::DenseTensor& ids, - pten::DenseTensor* value, + void Get(const DenseTensor& ids, + DenseTensor* value, bool auto_grown = false, bool is_test = false); @@ -149,6 +148,41 @@ class SelectedRows { return pten::framework::make_ddim(dims); } + /// \brief Returns the name of the class for type traits. + /// \return The name of the class. + static const char* name() { return "SelectedRows"; } + + /// \brief Returns the number of elements contained in tensor. + /// \return The number of elements contained in tensor. + int64_t numel() const override { return value_->numel(); }; + + /// \brief Returns the dims of the tensor. + /// \return The dims of the tensor. + const DDim& dims() const noexcept override { + return value_->dims(); + // return paddle::framework::make_ddim(dims); + } + + /// \brief Returns the data type of the tensor. + /// \return The data type of the tensor. + DataType dtype() const noexcept override { return value_->dtype(); } + + /// \brief Returns the data layout of the tensor. + /// \return The data layout of the tensor. + DataLayout layout() const noexcept override { return value_->layout(); } + + /// \brief Returns the data place of the tensor. + /// \return The data place of the tensor. + const Place& place() const override { return value_->place(); }; + + /// \brief Test whether the metadata is valid. + /// \return Whether the metadata is valid. + bool valid() const noexcept override { return value_->valid(); } + + /// \brief Test whether the storage is allocated. + /// return Whether the storage is allocated. + bool initialized() const override { return value_->initialized(); } + private: // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here. // SelectedRows are simply concated when adding together. Until a @@ -156,7 +190,7 @@ class SelectedRows { paddle::framework::Vector rows_; std::unordered_map id_to_index_; // should not be used when rows_ has duplicate member - std::unique_ptr value_{nullptr}; + std::unique_ptr value_{nullptr}; int64_t height_; // height indicates the underline tensor's height std::unique_ptr rwlock_{nullptr}; }; -- GitLab