未验证 提交 3e80253a 编写于 作者: W Weilong Wu 提交者: GitHub

[Move selected_rows PR #4] SelectedRows inherits from TensorBase. (#39162)

* Added selected_rows and rw_lock to pten

* Renamed the unit test target to fix CI

* Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid

* Remove rw_lock.h,rw_lock_test.cc in fluid

* Use pten::RWLock and pten::AutoRDLock, fix CI

* Use pten::SelectedRows

* Use pten::SelectedRows

* Fix to pass NPU CI

* Selected_Rows inherits from TensorBase

* Use pten::SelectedRows, to pass NPU CI

* To fix NPU CI

* To fix NPU CI again

* Use paddle/pten/core/enforce and polish code
上级 d9acc87e
......@@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim)
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
......
......@@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/rw_lock.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
namespace pten {
class SelectedRows {
class SelectedRows : public TensorBase,
public TypeInfoTraits<TensorBase, SelectedRows> {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`,
......@@ -51,21 +52,19 @@ class SelectedRows {
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}
SelectedRows() {
height_ = 0;
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}
const pten::Place& place() const { return value_->place(); }
const DenseTensor& value() const { return *value_; }
const pten::DenseTensor& value() const { return *value_; }
pten::DenseTensor* mutable_value() { return value_.get(); }
DenseTensor* mutable_value() { return value_.get(); }
int64_t height() const { return height_; }
......@@ -109,8 +108,8 @@ class SelectedRows {
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
void Get(const pten::DenseTensor& ids,
pten::DenseTensor* value,
void Get(const DenseTensor& ids,
DenseTensor* value,
bool auto_grown = false,
bool is_test = false);
......@@ -149,6 +148,41 @@ class SelectedRows {
return pten::framework::make_ddim(dims);
}
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "SelectedRows"; }
/// \brief Returns the number of elements contained in tensor.
/// \return The number of elements contained in tensor.
int64_t numel() const override { return value_->numel(); };
/// \brief Returns the dims of the tensor.
/// \return The dims of the tensor.
const DDim& dims() const noexcept override {
return value_->dims();
// return paddle::framework::make_ddim(dims);
}
/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return value_->dtype(); }
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return value_->layout(); }
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return value_->place(); };
/// \brief Test whether the metadata is valid.
/// \return Whether the metadata is valid.
bool valid() const noexcept override { return value_->valid(); }
/// \brief Test whether the storage is allocated.
/// return Whether the storage is allocated.
bool initialized() const override { return value_->initialized(); }
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simply concated when adding together. Until a
......@@ -156,7 +190,7 @@ class SelectedRows {
paddle::framework::Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when rows_ has duplicate member
std::unique_ptr<pten::DenseTensor> value_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
int64_t height_; // height indicates the underline tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册