Unverified commit 2bafd338, authored by Weilong Wu, committed by GitHub

[Move selected_rows PR #3] Change the relationship of [include/Cmake]. (#39128)

* Added selected_rows and rw_lock to pten

* Renamed the unit test target to fix CI

* Removed class SelectedRows from Fluid, changed the include/CMake relationship, and switched Fluid to pten::SelectedRows

* Removed rw_lock.h and rw_lock_test.cc from Fluid

* Use pten::RWLock and pten::AutoRDLock, fix CI

* Use pten::SelectedRows

* Use pten::SelectedRows

* Fix to pass NPU CI

* Use pten::SelectedRows, to pass NPU CI

* To fix NPU CI

* To fix NPU CI again
Parent 3825b40f
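What the commit amounts to, sketched below on a hypothetical helper (an illustration written for this review, not a line from the diff): every access to SelectedRows that used to go through the paddle::framework namespace now goes through pten, the class interface itself is unchanged, and includes of paddle/fluid/framework/rw_lock.h become includes of paddle/pten/core/utils/rw_lock.h.

    // Minimal sketch, assuming a Paddle source tree at this commit;
    // Example() and its argument are hypothetical.
    #include "paddle/fluid/framework/selected_rows_utils.h"
    #include "paddle/fluid/framework/variable.h"

    void Example(paddle::framework::Variable* var) {
      // Before: auto* slr = var->GetMutable<paddle::framework::SelectedRows>();
      auto* slr = var->GetMutable<pten::SelectedRows>();  // after this commit
      slr->set_height(16);  // the API is untouched; only the namespace moved
    }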
@@ -49,7 +49,7 @@ class PSCore;
 using framework::LoDTensor;
 using framework::Scope;
-using framework::SelectedRows;
+using pten::SelectedRows;
 using framework::Variable;
 using RpcCtxMap = std::unordered_map<std::string, CommContext>;
......
@@ -76,7 +76,7 @@ void SerializeToMultiVarMsgAndIOBuf(
     if (var->IsType<framework::LoDTensor>()) {
       SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
-    } else if (var->IsType<framework::SelectedRows>()) {
+    } else if (var->IsType<pten::SelectedRows>()) {
       SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
     }
     iobuf->append(temp_iobuf);
@@ -127,7 +127,7 @@ void SerializeLodTensor(framework::Variable* var,
 void SerializeSelectedRows(framework::Variable* var,
                            const platform::DeviceContext& ctx, VarMsg* var_msg,
                            butil::IOBuf* iobuf) {
-  framework::SelectedRows* slr = var->GetMutable<framework::SelectedRows>();
+  pten::SelectedRows* slr = var->GetMutable<pten::SelectedRows>();
   auto* tensor = slr->mutable_value();
   auto* rows = slr->mutable_rows();
@@ -255,7 +255,7 @@ void DeserializeSelectedRows(
     butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
     const platform::DeviceContext& ctx) {
   const auto place = ctx.GetPlace();
-  auto* slr = var->GetMutable<framework::SelectedRows>();
+  auto* slr = var->GetMutable<pten::SelectedRows>();
   framework::Tensor* tensor = slr->mutable_value();
   slr->set_height(msg.slr_height());
   std::vector<int64_t> tmp_rows(msg.dims()[0]);
......
@@ -28,7 +28,7 @@ namespace paddle {
 namespace distributed {
 using framework::LoDTensor;
-using framework::SelectedRows;
+using pten::SelectedRows;
 const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
@@ -293,7 +293,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
   std::vector<float *> push_g_vec;
   auto *send_var = scope.FindVar(var_name);
-  auto *tensor = send_var->GetMutable<SelectedRows>();
+  auto *tensor = send_var->GetMutable<pten::SelectedRows>();
   auto dim = tensor->value().dims()[1];
   std::transform(tensor->rows().begin(), tensor->rows().end(),
                  std::back_inserter(sparse_push_keys),
@@ -1012,10 +1012,10 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
     auto *var = scope.FindVar(table_name);
-    PADDLE_ENFORCE_EQ(var->IsType<framework::SelectedRows>(), true,
+    PADDLE_ENFORCE_EQ(var->IsType<pten::SelectedRows>(), true,
                       platform::errors::InvalidArgument(
                           "Only need to send Sparse Grad in Geo mode."));
-    auto &rows = var->Get<framework::SelectedRows>().rows();
+    auto &rows = var->Get<pten::SelectedRows>().rows();
     // insert ids which has not been record
     for (size_t j = 0; j < rows.size(); j++) {
@@ -1290,7 +1290,7 @@ void GeoCommunicator::SendSparse(const std::string &varname,
   auto cpu_ctx = paddle::platform::CPUDeviceContext();
   auto *var_delta = delta_scope_->Var(varname);
-  auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
+  auto *t_delta = var_delta->GetMutable<pten::SelectedRows>();
   auto *var_t_value = t_delta->mutable_value();
   var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
   auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());
......
@@ -193,15 +193,15 @@ inline void MergeVars(const std::string &var_name,
       result.device(*cpu_ctx.eigen_device()) =
           result / static_cast<T>(vars.size());
     }
-  } else if (var0->IsType<framework::SelectedRows>()) {
-    auto &slr0 = var0->Get<framework::SelectedRows>();
-    auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
+  } else if (var0->IsType<pten::SelectedRows>()) {
+    auto &slr0 = var0->Get<pten::SelectedRows>();
+    auto *out_slr = out_var->GetMutable<pten::SelectedRows>();
     out_slr->mutable_rows()->clear();
     out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
-    std::vector<const paddle::framework::SelectedRows *> inputs;
+    std::vector<const pten::SelectedRows *> inputs;
     inputs.reserve(vars.size());
     for (auto &var : vars) {
-      inputs.push_back(&var->Get<framework::SelectedRows>());
+      inputs.push_back(&var->Get<pten::SelectedRows>());
     }
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     if (merge_add) {
......
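For context on the MergeVars hunk above: a SelectedRows is a sparse row-wise representation, pairing a logical height, a list of row indices, and a dense value tensor that stores only those rows, so merging concatenates or adds rows rather than whole tensors. A minimal construction sketch (shapes and the BuildExample helper are hypothetical; the calls themselves all appear in this diff):

    #include "paddle/fluid/framework/selected_rows_utils.h"
    #include "paddle/fluid/framework/variable.h"
    #include "paddle/fluid/platform/place.h"

    void BuildExample(paddle::framework::Variable* var) {
      auto* slr = var->GetMutable<pten::SelectedRows>();
      slr->set_height(100);  // logical row count of the full tensor
      for (int64_t r : {0, 4, 7}) {
        slr->mutable_rows()->push_back(r);  // indices of the rows actually stored
      }
      auto* value = slr->mutable_value();  // dense storage: one row per index above
      value->Resize({3, 8});
      value->mutable_data<float>(paddle::platform::CPUPlace());
    }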
@@ -39,8 +39,10 @@
 #include "paddle/fluid/distributed/table/accessor.h"
 #include "paddle/fluid/distributed/table/common_table.h"
 #include "paddle/fluid/distributed/table/graph/graph_node.h"
-#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 namespace paddle {
 namespace distributed {
 class GraphShard {
......
@@ -29,8 +29,8 @@
 #include "paddle/fluid/distributed/table/depends/initializers.h"
 #include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
 #include "paddle/fluid/distributed/table/depends/sparse.h"
-#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 #define PSERVER_SAVE_SUFFIX ".shard"
@@ -110,7 +110,7 @@ struct Meta {
 class CommonSparseTable : public SparseTable {
  public:
-  CommonSparseTable() { rwlock_.reset(new framework::RWLock); }
+  CommonSparseTable() { rwlock_.reset(new pten::RWLock); }
   virtual ~CommonSparseTable() {}
   // unused method begin
@@ -193,7 +193,7 @@ class CommonSparseTable : public SparseTable {
   std::shared_ptr<SparseOptimizer> optimizer_;
   std::vector<std::shared_ptr<ValueBlock>> shard_values_;
   std::unordered_map<uint64_t, ReservoirValue<float>> pull_reservoir_;
-  std::unique_ptr<framework::RWLock> rwlock_{nullptr};
+  std::unique_ptr<pten::RWLock> rwlock_{nullptr};
 };
 }  // namespace distributed
......
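The rwlock_ migration above is representative of every RWLock change in this commit: the lock class and its RAII guards moved from paddle/fluid/framework/rw_lock.h to paddle/pten/core/utils/rw_lock.h with the interface intact. A usage sketch (assuming the AutoRDLock/AutoWRLock guard names the commit message cites; ReadMostly is a hypothetical function):

    #include "paddle/pten/core/utils/rw_lock.h"

    void ReadMostly(pten::RWLock* lock) {
      {
        pten::AutoRDLock guard(lock);  // shared: many readers may hold this
        // ... read the shard ...
      }  // released at scope exit
      {
        pten::AutoWRLock guard(lock);  // exclusive: a single writer
        // ... mutate the shard ...
      }
    }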
@@ -32,7 +32,6 @@
 #include "paddle/fluid/distributed/thirdparty/round_robin.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -43,6 +42,7 @@
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/pten/backends/dynload/port.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 namespace paddle {
 namespace distributed {
......
@@ -31,8 +31,8 @@
 #include "paddle/fluid/distributed/table/depends/initializers.h"
 #include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
 #include "paddle/fluid/distributed/table/depends/sparse.h"
-#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 namespace paddle {
 namespace distributed {
......
@@ -56,7 +56,7 @@ void CreateVarsOnScope(framework::Scope* scope, platform::Place* place,
   // var 3
   framework::Variable* var3 = scope->Var("x3");
-  auto* slr = var3->GetMutable<framework::SelectedRows>();
+  auto* slr = var3->GetMutable<pten::SelectedRows>();
   slr->set_height(564);
   auto* tensor3 = slr->mutable_value();
   auto* rows = slr->mutable_rows();
@@ -111,7 +111,7 @@ void RunMultiVarMsg(platform::Place place) {
   // check var3
   framework::Variable* var3 = scope_recv.FindVar("x3");
-  auto* slr = var3->GetMutable<framework::SelectedRows>();
+  auto* slr = var3->GetMutable<pten::SelectedRows>();
   EXPECT_EQ(slr->rows().size(), 564);
   for (int i = 0; i < 564; ++i) {
     EXPECT_EQ(slr->rows()[i], i);
......
@@ -197,9 +197,8 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
           out_var->GetMutable<paddle::framework::LoDTensor>();
       out_lod_tensor->Resize(in_lod_tensor.dims());
     } else {
-      auto& in_sele_rows = in_var->Get<paddle::framework::SelectedRows>();
-      auto out_sele_rows =
-          out_var->GetMutable<paddle::framework::SelectedRows>();
+      auto& in_sele_rows = in_var->Get<pten::SelectedRows>();
+      auto out_sele_rows = out_var->GetMutable<pten::SelectedRows>();
       out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
       out_sele_rows->set_rows(in_sele_rows.rows());
       out_sele_rows->set_height(in_sele_rows.height());
@@ -368,8 +367,8 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
                              "Input variable should not be null"));
     if (var->IsType<paddle::framework::LoDTensor>()) {
       return var->Get<paddle::framework::LoDTensor>().dims();
-    } else if (var->IsType<paddle::framework::SelectedRows>()) {
-      return var->Get<paddle::framework::SelectedRows>().GetCompleteDims();
+    } else if (var->IsType<pten::SelectedRows>()) {
+      return var->Get<pten::SelectedRows>().GetCompleteDims();
     } else {
       PADDLE_THROW(paddle::platform::errors::PermissionDenied(
           "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
@@ -385,8 +384,8 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   void SetDim(paddle::framework::Variable* var, const DDim& dim) {
     if (var->IsType<paddle::framework::LoDTensor>()) {
       var->GetMutable<paddle::framework::LoDTensor>()->Resize(dim);
-    } else if (var->IsType<paddle::framework::SelectedRows>()) {
-      var->GetMutable<paddle::framework::SelectedRows>()->set_height(dim[0]);
+    } else if (var->IsType<pten::SelectedRows>()) {
+      var->GetMutable<pten::SelectedRows>()->set_height(dim[0]);
     } else {
       PADDLE_THROW(paddle::platform::errors::PermissionDenied(
           "Variable type_id %s, expect LoDTensor/SelectedRows."));
......
@@ -32,8 +32,8 @@ const paddle::framework::Tensor* GetTensorFromVar(
     const paddle::framework::Variable& var) {
   if (var.IsType<paddle::framework::LoDTensor>()) {
     return &(var.Get<paddle::framework::LoDTensor>());
-  } else if (var.IsType<paddle::framework::SelectedRows>()) {
-    return &(var.Get<paddle::framework::SelectedRows>().value());
+  } else if (var.IsType<pten::SelectedRows>()) {
+    return &(var.Get<pten::SelectedRows>().value());
   } else {
     return nullptr;
   }
......
@@ -32,7 +32,7 @@ void InitializeVariable(paddle::framework::Variable *var,
   if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
     var->GetMutable<paddle::framework::LoDTensor>();
   } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
-    var->GetMutable<paddle::framework::SelectedRows>();
+    var->GetMutable<pten::SelectedRows>();
   } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) {
     var->GetMutable<paddle::framework::FeedList>();
   } else if (var_type == paddle::framework::proto::VarType::FETCH_LIST) {
@@ -72,9 +72,9 @@ void CopyVariable(const paddle::framework::Variable &src_var,
     auto &src_tensor = src_var.Get<paddle::framework::LoDTensor>();
     tmp_grad_tensor->set_lod(src_tensor.lod());
     paddle::framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
-  } else if (src_var.IsType<paddle::framework::SelectedRows>()) {
-    auto &src_slr = src_var.Get<paddle::framework::SelectedRows>();
-    auto *tmp_grad_slr = dst_var->GetMutable<paddle::framework::SelectedRows>();
+  } else if (src_var.IsType<pten::SelectedRows>()) {
+    auto &src_slr = src_var.Get<pten::SelectedRows>();
+    auto *tmp_grad_slr = dst_var->GetMutable<pten::SelectedRows>();
     tmp_grad_slr->set_rows(src_slr.rows());
     tmp_grad_slr->set_height(src_slr.height());
     auto &src_t = src_slr.value();
@@ -89,8 +89,8 @@ paddle::framework::proto::VarType::Type GetDtypeFromVar(
     const paddle::framework::Variable &var) {
   if (var.IsType<paddle::framework::LoDTensor>()) {
     return var.Get<paddle::framework::LoDTensor>().type();
-  } else if (var.IsType<paddle::framework::SelectedRows>()) {
-    return var.Get<paddle::framework::SelectedRows>().value().type();
+  } else if (var.IsType<pten::SelectedRows>()) {
+    return var.Get<pten::SelectedRows>().value().type();
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "Variable type is %s, expect LoDTensor or SelectedRows.",
@@ -101,8 +101,8 @@ const paddle::platform::Place &GetPlaceFromVar(
     const paddle::framework::Variable &var) {
   if (var.IsType<paddle::framework::LoDTensor>()) {
     return var.Get<paddle::framework::LoDTensor>().place();
-  } else if (var.IsType<paddle::framework::SelectedRows>()) {
-    return var.Get<paddle::framework::SelectedRows>().place();
+  } else if (var.IsType<pten::SelectedRows>()) {
+    return var.Get<pten::SelectedRows>().place();
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "Variable type is %s, expect LoDTensor or SelectedRows.",
......
@@ -383,7 +383,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto boost)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
 cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
         proto_desc)
-cc_library(selected_rows_utils SRCS selected_rows_utils.cc DEPS tensor)
+cc_library(selected_rows_utils SRCS selected_rows_utils.cc DEPS selected_rows)
 cc_test(selected_rows_utils_test SRCS selected_rows_utils_test.cc DEPS selected_rows_utils)
 cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
@@ -393,10 +393,6 @@ cc_test(tuple_test SRCS tuple_test.cc )
 cc_test(inlined_vector_test SRCS inlined_vector_test.cc)
-if (NOT WIN32)
-cc_test(rw_lock_test SRCS rw_lock_test.cc)
-endif (NOT WIN32)
 cc_library(dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack)
 cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
......
@@ -120,9 +120,9 @@ void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
     tran_lod_tensor->set_format(in_lod_tensor.format());
 #endif
     tran_lod_tensor->ShareDataWith(tensor);
-  } else if (in_var.IsType<SelectedRows>()) {
-    auto &in_selected_rows = in_var.Get<SelectedRows>();
-    auto *trans_selected_rows = out_var->GetMutable<SelectedRows>();
+  } else if (in_var.IsType<pten::SelectedRows>()) {
+    auto &in_selected_rows = in_var.Get<pten::SelectedRows>();
+    auto *trans_selected_rows = out_var->GetMutable<pten::SelectedRows>();
     trans_selected_rows->set_height(in_selected_rows.height());
     trans_selected_rows->set_rows(in_selected_rows.rows());
     trans_selected_rows->mutable_value()->ShareDataWith(tensor);
......
@@ -237,7 +237,7 @@ struct TestBroadcastOpHandle {
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::NotFound("Variable %s is not found in scope.",
                                         varname));
-    auto selected_rows = var->GetMutable<f::SelectedRows>();
+    auto selected_rows = var->GetMutable<pten::SelectedRows>();
     auto value = selected_rows->mutable_value();
     value->mutable_data<float>(kDims, place_list_[input_scope_idx]);
     selected_rows->set_height(height);
@@ -256,7 +256,7 @@ struct TestBroadcastOpHandle {
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::NotFound("Variable %s is not found in scope.",
                                         varname));
-    auto& selected_rows = var->Get<f::SelectedRows>();
+    auto& selected_rows = var->Get<pten::SelectedRows>();
     auto rt = selected_rows.value();
     PADDLE_ENFORCE_EQ(selected_rows.height(), height,
                       platform::errors::InvalidArgument(
......
@@ -129,9 +129,10 @@ void EagerDeletionOpHandle::RunImpl() {
       if (var->IsType<LoDTensor>()) {
         garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
-      } else if (var->IsType<SelectedRows>()) {
-        garbages.emplace_back(
-            var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
+      } else if (var->IsType<pten::SelectedRows>()) {
+        garbages.emplace_back(var->GetMutable<pten::SelectedRows>()
+                                  ->mutable_value()
+                                  ->MoveMemoryHolder());
       } else if (var->IsType<LoDTensorArray>()) {
         auto *tensor_arr = var->GetMutable<LoDTensorArray>();
         for (auto &t : *tensor_arr) {
......
@@ -64,14 +64,14 @@ void GatherOpHandle::RunImpl() {
       platform::errors::NotFound("The variable '%s' is not found in the scope.",
                                  in_0_handle->name()));
-  PADDLE_ENFORCE_EQ(pre_in_var->IsType<framework::SelectedRows>(), true,
+  PADDLE_ENFORCE_EQ(pre_in_var->IsType<pten::SelectedRows>(), true,
                     platform::errors::Unimplemented(
                         "Currently, gather_op only supports SelectedRows."));
   // Wait input done, this Wait is asynchronous operation
   WaitInputVarGenerated();
-  auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
+  auto &pre_in_value = pre_in_var->Get<pten::SelectedRows>();
   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
@@ -85,7 +85,7 @@ void GatherOpHandle::RunImpl() {
             "The variable '%s' is not found in the scope.", in_handle->name()));
     VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);
-    auto &in_sr_value = in_var->Get<framework::SelectedRows>();
+    auto &in_sr_value = in_var->Get<pten::SelectedRows>();
     auto &in_sr_rows = in_sr_value.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
@@ -108,7 +108,7 @@ void GatherOpHandle::RunImpl() {
       out_var,
       platform::errors::NotFound("The variable '%s' is not found in the scope.",
                                  out_var_handle->name()));
-  auto out_value = out_var->GetMutable<framework::SelectedRows>();
+  auto out_value = out_var->GetMutable<pten::SelectedRows>();
   out_value->set_height(pre_in_value.height());
   out_value->set_rows(out_rows);
   size_t rows = out_rows.size();
......
@@ -146,7 +146,7 @@ struct TestGatherOpHandle {
     PADDLE_ENFORCE_NOT_NULL(
         in_var, platform::errors::NotFound(
                     "The variable '%s' is not found in the scope.", "input"));
-    auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
+    auto in_selected_rows = in_var->GetMutable<pten::SelectedRows>();
     auto value = in_selected_rows->mutable_value();
     value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
@@ -162,10 +162,10 @@ struct TestGatherOpHandle {
     PADDLE_ENFORCE_NOT_NULL(
         out_var, platform::errors::NotFound(
                      "The variable '%s' is not found in the scope.", "out"));
-    auto out_selected_rows = out_var->GetMutable<f::SelectedRows>();
+    auto out_selected_rows = out_var->GetMutable<pten::SelectedRows>();
     auto in_var = param_scopes_.at(output_scope_idx)->FindVar("input");
-    auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
+    auto in_selected_rows = in_var->GetMutable<pten::SelectedRows>();
     out_selected_rows->mutable_value()->ShareDataWith(
         in_selected_rows->value());
@@ -177,7 +177,7 @@ struct TestGatherOpHandle {
     p::CPUPlace cpu_place;
-    auto& out_select_rows = out_var->Get<f::SelectedRows>();
+    auto& out_select_rows = out_var->Get<pten::SelectedRows>();
     auto rt = out_select_rows.value();
     PADDLE_ENFORCE_EQ(out_select_rows.height(), height,
......
@@ -321,8 +321,8 @@ void CheckVarHasNanOrInf(const std::string& op_type,
   const Tensor* tensor{nullptr};
   if (var->IsType<framework::LoDTensor>()) {
     tensor = &var->Get<framework::LoDTensor>();
-  } else if (var->IsType<framework::SelectedRows>()) {
-    tensor = &var->Get<framework::SelectedRows>().value();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    tensor = &var->Get<pten::SelectedRows>().value();
   } else {
     VLOG(10) << var_name << " var_name need not to check";
     return;
@@ -468,8 +468,8 @@ void PrintNpuVarInfo(const std::string& op_type, const std::string& var_name,
   const Tensor* tensor{nullptr};
   if (var->IsType<framework::LoDTensor>()) {
     tensor = &var->Get<framework::LoDTensor>();
-  } else if (var->IsType<framework::SelectedRows>()) {
-    tensor = &var->Get<framework::SelectedRows>().value();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    tensor = &var->Get<pten::SelectedRows>().value();
   } else {
     VLOG(10) << var_name << " var_name need not to check";
     return;
......
@@ -20,6 +20,11 @@
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
+namespace pten {
+class SelectedRows;
+}  // namespace pten
 namespace paddle {
 namespace framework {
 namespace details {
@@ -96,10 +101,10 @@ struct ReduceBufferData {
 struct GatherLocalSelectedRowsFunctor {
   GatherLocalSelectedRowsFunctor(
-      const std::vector<const SelectedRows *> &src_selected_rows,
+      const std::vector<const pten::SelectedRows *> &src_selected_rows,
       const std::vector<platform::Place> &in_places,
       const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
-      const platform::Place &out_place, SelectedRows *dst_selected_rows)
+      const platform::Place &out_place, pten::SelectedRows *dst_selected_rows)
       : dev_ctxes_(dev_ctxes),
         in_places_(in_places),
         out_place_(out_place),
@@ -147,7 +152,7 @@ struct GatherLocalSelectedRowsFunctor {
   std::vector<Tensor> in_tensors_;
   platform::Place out_place_;
-  SelectedRows *dst_selected_rows_;
+  pten::SelectedRows *dst_selected_rows_;
 };
 }  // namespace details
......
@@ -114,10 +114,10 @@ void ReduceOpHandle::RunImpl() {
     t_out_p = platform::CPUPlace();
   }
-  if (pre_in_var->IsType<framework::SelectedRows>()) {
+  if (pre_in_var->IsType<pten::SelectedRows>()) {
     this->RunAndRecordEvent([&] {
-      std::vector<const SelectedRows *> in_selected_rows =
-          GetInputValues<SelectedRows>(in_var_handles, var_scopes);
+      std::vector<const pten::SelectedRows *> in_selected_rows =
+          GetInputValues<pten::SelectedRows>(in_var_handles, var_scopes);
       const CollectiveContext &collective_context =
           *CollectiveContext::GetInstance();
@@ -130,7 +130,7 @@ void ReduceOpHandle::RunImpl() {
           platform::is_cpu_place(t_out_p)) {
         GatherLocalSelectedRowsFunctor functor(
             in_selected_rows, in_places, dev_ctxes_, t_out_p,
-            out_var->GetMutable<framework::SelectedRows>());
+            out_var->GetMutable<pten::SelectedRows>());
         WaitInputVarGenerated();
         functor();
         return;
......
@@ -27,7 +27,6 @@
 namespace paddle {
 namespace framework {
-class SelectedRows;
 namespace details {
 struct VarHandle;
@@ -131,11 +130,11 @@ struct ReduceOpHandle : public OpHandleBase {
     defined PADDLE_WITH_DISTRIBUTE
   template <typename DevCtx, typename DataType>
   void GatherSelectedRows(
-      const std::vector<const SelectedRows *> &src_selecte_rows_,
+      const std::vector<const pten::SelectedRows *> &src_selecte_rows_,
      const std::vector<platform::Place> &in_places,
      const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
      VarHandle *out_var_handle, const platform::Place &out_place,
-      SelectedRows *dst_selecte_rows);
+      pten::SelectedRows *dst_selecte_rows);
 #endif
   void Wait(
......
@@ -174,7 +174,7 @@ struct TestReduceOpHandle {
     PADDLE_ENFORCE_NOT_NULL(
         in_var, platform::errors::NotFound(
                     "Variable %s is not found in scope.", "input"));
-    auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
+    auto in_selected_rows = in_var->GetMutable<pten::SelectedRows>();
     auto value = in_selected_rows->mutable_value();
     value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
@@ -190,10 +190,10 @@ struct TestReduceOpHandle {
     PADDLE_ENFORCE_NOT_NULL(out_var,
                             platform::errors::NotFound(
                                 "Variable %s is not found in scope.", "out"));
-    auto out_selected_rows = out_var->GetMutable<f::SelectedRows>();
+    auto out_selected_rows = out_var->GetMutable<pten::SelectedRows>();
     auto in_var = param_scopes_[output_scope_idx]->FindVar("input");
-    auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
+    auto in_selected_rows = in_var->GetMutable<pten::SelectedRows>();
     out_selected_rows->mutable_value()->ShareDataWith(
         in_selected_rows->value());
@@ -205,7 +205,7 @@ struct TestReduceOpHandle {
     p::CPUPlace cpu_place;
-    auto &out_select_rows = out_var->Get<f::SelectedRows>();
+    auto &out_select_rows = out_var->Get<pten::SelectedRows>();
     auto rt = out_select_rows.value();
     PADDLE_ENFORCE_EQ(out_select_rows.height(), height,
......
@@ -33,9 +33,9 @@ static void GetTensors(Variable *var,
                        std::unordered_set<Tensor *> *tensor_set) {
   if (var->IsType<LoDTensor>() && var->Get<LoDTensor>().IsInitialized()) {
     tensor_set->insert(var->GetMutable<LoDTensor>());
-  } else if (var->IsType<SelectedRows>() &&
-             var->Get<SelectedRows>().value().IsInitialized()) {
-    tensor_set->insert(var->GetMutable<SelectedRows>()->mutable_value());
+  } else if (var->IsType<pten::SelectedRows>() &&
+             var->Get<pten::SelectedRows>().value().IsInitialized()) {
+    tensor_set->insert(var->GetMutable<pten::SelectedRows>()->mutable_value());
   } else if (var->IsType<LoDTensorArray>()) {
     auto *tensor_arr = var->GetMutable<LoDTensorArray>();
     for (auto &t : *tensor_arr) {
......
@@ -33,8 +33,8 @@ template <typename Func>
 static void VisitVariable(Variable* var, Func* func) {
   if (var->IsType<LoDTensor>()) {
     (*func)(var->GetMutable<LoDTensor>());
-  } else if (var->IsType<SelectedRows>()) {
-    (*func)(var->GetMutable<SelectedRows>());
+  } else if (var->IsType<pten::SelectedRows>()) {
+    (*func)(var->GetMutable<pten::SelectedRows>());
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "VisitVariable is not supported for type %s.",
@@ -46,8 +46,8 @@ template <typename Func>
 static void VisitVariable(const Variable& var, Func* func) {
   if (var.IsType<LoDTensor>()) {
     (*func)(var.Get<LoDTensor>());
-  } else if (var.IsType<SelectedRows>()) {
-    (*func)(var.Get<SelectedRows>());
+  } else if (var.IsType<pten::SelectedRows>()) {
+    (*func)(var.Get<pten::SelectedRows>());
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "VisitVariable is not supported for type %s.", ToTypeName(var.Type())));
@@ -59,7 +59,7 @@ struct TensorVisitor {
   void operator()(LoDTensor* tensor) { result_ = tensor; }
-  void operator()(SelectedRows* selected_rows) {
+  void operator()(pten::SelectedRows* selected_rows) {
     result_ = selected_rows->mutable_value();
   }
@@ -85,8 +85,8 @@ struct ShareDimsAndLoDVisitor {
     tensor->Resize(val.dims());
   }
-  void operator()(const SelectedRows& val) {
-    auto* selected_rows = trg_->GetMutable<SelectedRows>();
+  void operator()(const pten::SelectedRows& val) {
+    auto* selected_rows = trg_->GetMutable<pten::SelectedRows>();
     selected_rows->set_rows(val.rows());
     selected_rows->set_height(val.height());
     selected_rows->mutable_value()->Resize(val.value().dims());
@@ -131,8 +131,8 @@ struct EnforceShapeAndDTypeEQVisitor {
         "The layout of the two variables' tensors tensor is not equal."));
   }
-  void operator()(const SelectedRows& src) {
-    auto& selected_rows = dst_->Get<SelectedRows>();
+  void operator()(const pten::SelectedRows& src) {
+    auto& selected_rows = dst_->Get<pten::SelectedRows>();
     PADDLE_ENFORCE_EQ(
         src.place().GetType(), selected_rows.place().GetType(),
         platform::errors::PreconditionNotMet(
......
@@ -815,8 +815,8 @@ void DownpourWorker::TrainFiles() {
         if (var->IsType<framework::LoDTensor>()) {
           tensor = var->GetMutable<LoDTensor>();
           len = tensor->numel();
-        } else if (var->IsType<SelectedRows>()) {
-          auto selected_rows = var->GetMutable<SelectedRows>();
+        } else if (var->IsType<pten::SelectedRows>()) {
+          auto selected_rows = var->GetMutable<pten::SelectedRows>();
           tensor = selected_rows->mutable_value();
           len = tensor->numel();
         }
......
@@ -147,9 +147,10 @@ void DeleteUnusedTensors(const Scope &scope,
     VLOG(2) << "Erase variable " << var_name;
     if (var->IsType<LoDTensor>()) {
       garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
-    } else if (var->IsType<SelectedRows>()) {
-      garbages.emplace_back(
-          var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
+    } else if (var->IsType<pten::SelectedRows>()) {
+      garbages.emplace_back(var->GetMutable<pten::SelectedRows>()
+                                ->mutable_value()
+                                ->MoveMemoryHolder());
    } else if (var->IsType<LoDTensorArray>()) {
       auto *lod_tensor_arr = var->GetMutable<LoDTensorArray>();
       for (auto &t : *lod_tensor_arr) {
......
@@ -23,7 +23,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_PSCORE
 #include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
 #endif
-#include "paddle/fluid/framework/rw_lock.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 #include "thrust/pair.h"
 // #include "cudf/concurrent_unordered_map.cuh.h"
 #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
@@ -81,7 +81,7 @@ class HashTable {
             << " push value size: " << push_grad_value_size_;
   }
-  std::unique_ptr<RWLock> rwlock_{nullptr};
+  std::unique_ptr<pten::RWLock> rwlock_{nullptr};
  private:
   TableContainer<KeyType, ValType>* container_;
......
@@ -121,7 +121,7 @@ __global__ void dy_mf_update_kernel(Table* table,
 template <typename KeyType, typename ValType>
 HashTable<KeyType, ValType>::HashTable(size_t capacity) {
   container_ = new TableContainer<KeyType, ValType>(capacity);
-  rwlock_.reset(new RWLock);
+  rwlock_.reset(new pten::RWLock);
 }
 template <typename KeyType, typename ValType>
......
@@ -136,7 +136,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
     if (!root_var) {
       continue;
     }
-    if (root_var->IsType<SelectedRows>()) {
+    if (root_var->IsType<pten::SelectedRows>()) {
      continue;
     }
     LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
......
@@ -259,7 +259,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
       auto var = var_name_item.second[i];
       auto& var_name = new_ins[var_name_item.first].at(i);
       const Tensor* tensor_in;
-      if (var->IsType<LoDTensor>() || var->IsType<SelectedRows>()) {
+      if (var->IsType<LoDTensor>() || var->IsType<pten::SelectedRows>()) {
         tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
       } else if (var->IsType<LoDTensorArray>()) {
         tensor_in =
......
@@ -676,8 +676,9 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
                  operators::reader::
                      OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
     // do nothing
-  } else if (var->IsType<SelectedRows>()) {
-    TensorRecordStream(*(var->GetMutable<SelectedRows>()->mutable_value()));
+  } else if (var->IsType<pten::SelectedRows>()) {
+    TensorRecordStream(
+        *(var->GetMutable<pten::SelectedRows>()->mutable_value()));
   } else if (var->IsType<LoDTensorArray>()) {
     auto* tensor_arr = var->GetMutable<LoDTensorArray>();
     for (auto& tensor : *tensor_arr) {
......
@@ -76,10 +76,12 @@ void InterpreterCoreEventGarbageCollector::Add(
   } else if (var->IsType<LoDRankTable>()) {
     // TODO(xiongkun03) in old executor, this type of variable is not support
     // eager deletion. so we just leave it here ?
-  } else if (var->IsType<SelectedRows>()) {
-    Add(var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder(),
+  } else if (var->IsType<pten::SelectedRows>()) {
+    Add(var->GetMutable<pten::SelectedRows>()
+            ->mutable_value()
+            ->MoveMemoryHolder(),
         event, ctx);
-    var->GetMutable<SelectedRows>()->mutable_rows()->clear();
+    var->GetMutable<pten::SelectedRows>()->mutable_rows()->clear();
   } else if (var->IsType<LoDTensorArray>()) {
     auto* tensor_arr = var->GetMutable<LoDTensorArray>();
     for (auto& t : *tensor_arr) {
@@ -132,4 +134,4 @@ void InterpreterCoreEventGarbageCollector::Free(
   }
 }
 }  // namespace framework
-}  // namespace paddle
\ No newline at end of file
+}  // namespace paddle
@@ -32,9 +32,11 @@ void InterpreterCoreFastGarbageCollector::Add(Variable* var) {
   } else if (var->IsType<LoDRankTable>()) {
     // TODO(xiongkun03) in old executor, this type of variable is not support
     // eager deletion. so we just leave it here ?
-  } else if (var->IsType<SelectedRows>()) {
-    Add(var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder());
-    var->GetMutable<SelectedRows>()->mutable_rows()->clear();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    Add(var->GetMutable<pten::SelectedRows>()
+            ->mutable_value()
+            ->MoveMemoryHolder());
+    var->GetMutable<pten::SelectedRows>()->mutable_rows()->clear();
   } else if (var->IsType<LoDTensorArray>()) {
     auto* tensor_arr = var->GetMutable<LoDTensorArray>();
     for (auto& t : *tensor_arr) {
......
@@ -468,8 +468,8 @@ void build_op_func_list(const platform::Place& place,
         if (var->IsType<LoDTensor>()) {
           garbages->emplace_back(
               var->GetMutable<LoDTensor>()->MoveMemoryHolder());
-        } else if (var->IsType<SelectedRows>()) {
-          garbages->emplace_back(var->GetMutable<SelectedRows>()
+        } else if (var->IsType<pten::SelectedRows>()) {
+          garbages->emplace_back(var->GetMutable<pten::SelectedRows>()
                                      ->mutable_value()
                                      ->MoveMemoryHolder());
         } else if (var->IsType<LoDTensorArray>()) {
......
@@ -18,7 +18,7 @@
 #include <vector>
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
-#include "paddle/fluid/framework/rw_lock.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 // When in inference scenario, the scopes will not be written by two threads in
 // a mean time, but a scope may be read by multiple threads concurrently, and
@@ -171,9 +171,9 @@ void InterpretercoreInferShapeContext::ShareDim(const std::string& in,
       platform::errors::InvalidArgument(
           "The type of input (%s) and output (%s) are inconsistent.", in, out));
-  if (in_var->IsType<framework::SelectedRows>()) {
-    auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
-    auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
+  if (in_var->IsType<pten::SelectedRows>()) {
+    auto& in_sele_rows = in_var->Get<pten::SelectedRows>();
+    auto out_sele_rows = out_var->GetMutable<pten::SelectedRows>();
     out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
     out_sele_rows->set_rows(in_sele_rows.rows());
     out_sele_rows->set_height(in_sele_rows.height());
@@ -392,8 +392,8 @@ DDim InterpretercoreInferShapeContext::GetDim(Variable* var) const {
       var, platform::errors::InvalidArgument("Input variable is nullptr."));
   if (var->IsType<LoDTensor>()) {
     return var->Get<LoDTensor>().dims();
-  } else if (var->IsType<SelectedRows>()) {
-    return var->Get<SelectedRows>().GetCompleteDims();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    return var->Get<pten::SelectedRows>().GetCompleteDims();
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Only LoDTensor or SelectedRows support 'GetDim', but input "
@@ -420,8 +420,8 @@ std::vector<DDim> InterpretercoreInferShapeContext::GetRepeatedDims(
 void InterpretercoreInferShapeContext::SetDim(Variable* var, const DDim& dim) {
   if (var->IsType<LoDTensor>()) {
     var->GetMutable<LoDTensor>()->Resize(dim);
-  } else if (var->IsType<SelectedRows>()) {
-    var->GetMutable<SelectedRows>()->set_height(dim[0]);
+  } else if (var->IsType<pten::SelectedRows>()) {
+    var->GetMutable<pten::SelectedRows>()->set_height(dim[0]);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Variable type error, expect LoDTensor or SelectedRows, but received "
......
@@ -19,10 +19,10 @@
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/platform/device_event_base.h"
 #include "paddle/fluid/platform/event.h"
+#include "paddle/pten/core/utils/rw_lock.h"
 // When in inference scenario, the scopes will not be written by two threads in
 // a mean time, but a scope may be read by multiple threads concurrently, and
......
@@ -77,11 +77,11 @@ static DDim GetDimsDebug(const ScopeBase& scope, const std::string& name,
   if (var->IsType<LoDTensor>()) {
     const LoDTensor& tensor = var->Get<LoDTensor>();
     return tensor.dims();
-  } else if (var->IsType<SelectedRows>()) {
+  } else if (var->IsType<pten::SelectedRows>()) {
     if (get_actual_dim) {
-      return var->Get<SelectedRows>().value().dims();
+      return var->Get<pten::SelectedRows>().value().dims();
     } else {
-      return var->Get<SelectedRows>().GetCompleteDims();
+      return var->Get<pten::SelectedRows>().GetCompleteDims();
     }
   } else if (var->IsType<Strings>()) {
     return DDim({static_cast<int64_t>(var->Get<Strings>().size())});
@@ -108,8 +108,8 @@ static std::string GetDtype(const ScopeBase& scope, const std::string& name) {
       return "";
     }
     return DataTypeToString(tensor.type());
-  } else if (var->IsType<SelectedRows>()) {
-    auto tensor = var->Get<SelectedRows>().value();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    auto tensor = var->Get<pten::SelectedRows>().value();
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "uninited";
     } else {
@@ -139,8 +139,8 @@ static std::string GetPlace(const ScopeBase& scope, const std::string& name) {
       return "";
     }
     return to_string(tensor.place());
-  } else if (var->IsType<SelectedRows>()) {
-    auto tensor = var->Get<SelectedRows>().value();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    auto tensor = var->Get<pten::SelectedRows>().value();
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "uninited";
     } else {
@@ -157,8 +157,8 @@ static int GetRowSize(const ScopeBase& scope, const std::string& name) {
     return -1;
   }
-  if (var->IsType<SelectedRows>()) {
-    return var->Get<SelectedRows>().rows().size();
+  if (var->IsType<pten::SelectedRows>()) {
+    return var->Get<pten::SelectedRows>().rows().size();
   }
   return -1;
@@ -497,8 +497,8 @@ void OperatorBase::GenerateTemporaryNames() {
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
   if (var.IsType<LoDTensor>()) {
     return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
-  } else if (var.IsType<SelectedRows>()) {
-    return &(var.Get<SelectedRows>().value());
+  } else if (var.IsType<pten::SelectedRows>()) {
+    return &(var.Get<pten::SelectedRows>().value());
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Variable type is %s, expect LoDTensor or SelectedRows.",
@@ -509,8 +509,8 @@ const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
   if (var->IsType<LoDTensor>()) {
     return var->GetMutable<LoDTensor>();
-  } else if (var->IsType<SelectedRows>()) {
-    return var->GetMutable<SelectedRows>()->mutable_value();
+  } else if (var->IsType<pten::SelectedRows>()) {
+    return var->GetMutable<pten::SelectedRows>()->mutable_value();
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Variable type is %s, expect LoDTensor or SelectedRows.",
@@ -741,9 +741,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
           "The type of input (%s) and output (%s) are inconsistent.", in,
           out));
-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
-      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
+    if (in_var->IsType<pten::SelectedRows>()) {
+      auto& in_sele_rows = in_var->Get<pten::SelectedRows>();
+      auto out_sele_rows = out_var->GetMutable<pten::SelectedRows>();
       out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
       out_sele_rows->set_rows(in_sele_rows.rows());
       out_sele_rows->set_height(in_sele_rows.height());
...@@ -950,8 +950,8 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -950,8 +950,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
var, platform::errors::InvalidArgument("Input variable is nullptr.")); var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<pten::SelectedRows>().GetCompleteDims();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Only LoDTensor or SelectedRows support 'GetDim', but input " "Only LoDTensor or SelectedRows support 'GetDim', but input "
...@@ -976,8 +976,8 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -976,8 +976,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
void SetDim(Variable* var, const DDim& dim) { void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<pten::SelectedRows>()->set_height(dim[0]);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect LoDTensor or SelectedRows, but received " "Variable type error, expect LoDTensor or SelectedRows, but received "
...@@ -1646,8 +1646,8 @@ void OperatorWithKernel::ParseInputDataType( ...@@ -1646,8 +1646,8 @@ void OperatorWithKernel::ParseInputDataType(
t = &var->Get<Tensor>(); t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) { } else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>(); t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<pten::SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) { } else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>(); auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) { for (size_t j = 0; j < t_arr->size(); j++) {
...@@ -1728,8 +1728,8 @@ Tensor* OperatorWithKernel::GetTensorFormInputSafely( ...@@ -1728,8 +1728,8 @@ Tensor* OperatorWithKernel::GetTensorFormInputSafely(
t = var->GetMutable<Tensor>(); t = var->GetMutable<Tensor>();
} else if (var->IsType<LoDTensor>()) { } else if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>(); t = var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value(); t = var->GetMutable<pten::SelectedRows>()->mutable_value();
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input variable type in complex type promotion.")); "Unsupported input variable type in complex type promotion."));
......
...@@ -117,7 +117,7 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) { ...@@ -117,7 +117,7 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
} }
inline bool VarIsTensor(const Variable& var) { inline bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>(); return var.IsType<LoDTensor>() || var.IsType<pten::SelectedRows>();
} }
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
...@@ -473,7 +473,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -473,7 +473,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
} }
bool IsSelectedRowsInput(const std::string& name) const override { bool IsSelectedRowsInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::SelectedRows>(); return ctx_.InputVar(name)->IsType<pten::SelectedRows>();
} }
private: private:
......
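For call sites, the two helpers above give a uniform way to reach the dense tensor behind either variable kind. A minimal sketch (hedged: PrintNumel is illustrative and not part of this patch):

#include "paddle/fluid/framework/operator.h"

// Illustrative only: report the element count of whichever tensor the
// variable holds, via the dispatch helpers declared above.
void PrintNumel(const paddle::framework::Variable& var) {
  namespace fw = paddle::framework;
  if (fw::VarIsTensor(var)) {  // true for LoDTensor and pten::SelectedRows
    const fw::Tensor* t = fw::GetLoDTensorOrSelectedRowsValueFromVar(var);
    VLOG(3) << "numel = " << t->numel();
  }
}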
...@@ -456,7 +456,7 @@ TEST(IndicateVarDataTypeTest, selectedrows) { ...@@ -456,7 +456,7 @@ TEST(IndicateVarDataTypeTest, selectedrows) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("selected_rows_1"); auto* var = scope.Var("selected_rows_1");
var->GetMutable<paddle::framework::SelectedRows>(); var->GetMutable<pten::SelectedRows>();
bool caught = false; bool caught = false;
try { try {
......
...@@ -38,12 +38,12 @@ ...@@ -38,12 +38,12 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h" #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/pten/core/utils/rw_lock.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -75,7 +75,7 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -75,7 +75,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
bool exist = false; bool exist = false;
{ {
AutoRDLock r_guard{&rwlock_}; pten::AutoRDLock r_guard{&rwlock_};
exist = cache_by_address_.count(cur_key_by_address) != 0; exist = cache_by_address_.count(cur_key_by_address) != 0;
// if the graph cannot be found by address, check whether the graph structure // if the graph cannot be found by address, check whether the graph structure
// has been stored in the cache. // has been stored in the cache.
...@@ -96,13 +96,13 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -96,13 +96,13 @@ const CinnCompiledObject& CinnCompiler::Compile(
std::int64_t compiled_num = real_compiled_num_.fetch_add(1); std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res = auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream); CompileGraph(graph, input_tensors, target, compiled_num, stream);
AutoWRLock w_guard{&rwlock_}; pten::AutoWRLock w_guard{&rwlock_};
if (!cache_by_struct_.count(cur_key_by_struct)) { if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_address_[cur_key_by_address] = compiled_res.get(); cache_by_address_[cur_key_by_address] = compiled_res.get();
cache_by_struct_[cur_key_by_struct] = std::move(compiled_res); cache_by_struct_[cur_key_by_struct] = std::move(compiled_res);
} }
} }
AutoRDLock guard{&rwlock_}; pten::AutoRDLock guard{&rwlock_};
const auto& cached_obj = *cache_by_address_[cur_key_by_address]; const auto& cached_obj = *cache_by_address_[cur_key_by_address];
return cached_obj; return cached_obj;
} }
...@@ -198,7 +198,7 @@ std::string CinnCompiler::ReadableKey( ...@@ -198,7 +198,7 @@ std::string CinnCompiler::ReadableKey(
void CinnCompiler::Clear() { void CinnCompiler::Clear() {
{ {
AutoWRLock guard{&rwlock_}; pten::AutoWRLock guard{&rwlock_};
graphs_.clear(); graphs_.clear();
cache_by_address_.clear(); cache_by_address_.clear();
cache_by_struct_.clear(); cache_by_struct_.clear();
......
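The Compile hunks above follow a read-mostly caching pattern: probe the cache under a shared lock, compile with no lock held, then insert under an exclusive lock with a re-check. A condensed, self-contained sketch of the same pattern (Key, Compiled and CompileSlow are stand-ins, not CinnCompiler's real types):

#include <map>
#include <memory>
#include <string>

#include "paddle/pten/core/utils/rw_lock.h"

using Key = std::string;
struct Compiled {
  std::string artifact;
};

// Stand-in for the expensive CompileGraph(...) call.
std::unique_ptr<Compiled> CompileSlow(const Key& key) {
  auto res = std::make_unique<Compiled>();
  res->artifact = "compiled:" + key;
  return res;
}

class Cache {
 public:
  const Compiled& GetOrCompile(const Key& key) {
    bool exist = false;
    {
      pten::AutoRDLock r_guard{&rwlock_};  // shared lock: cheap concurrent probe
      exist = cache_.count(key) != 0;
    }
    if (!exist) {
      auto res = CompileSlow(key);         // compile while holding no lock
      pten::AutoWRLock w_guard{&rwlock_};  // exclusive lock for the insert
      if (cache_.count(key) == 0) {        // re-check: another thread may have won
        cache_[key] = std::move(res);
      }
    }
    pten::AutoRDLock guard{&rwlock_};
    return *cache_.at(key);
  }

 private:
  std::map<Key, std::unique_ptr<Compiled>> cache_;
  mutable pten::RWLock rwlock_;
};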
...@@ -26,9 +26,9 @@ ...@@ -26,9 +26,9 @@
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/pten/core/utils/rw_lock.h"
namespace paddle { namespace paddle {
...@@ -102,7 +102,7 @@ class CinnCompiler { ...@@ -102,7 +102,7 @@ class CinnCompiler {
std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash> std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>
cache_by_struct_; cache_by_struct_;
std::atomic_int64_t real_compiled_num_{0}; std::atomic_int64_t real_compiled_num_{0};
mutable RWLock rwlock_; mutable pten::RWLock rwlock_;
DISABLE_COPY_AND_ASSIGN(CinnCompiler); DISABLE_COPY_AND_ASSIGN(CinnCompiler);
}; };
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if !defined(_WIN32)
#include <pthread.h>
#else
#include <mutex> // NOLINT
#endif // !_WIN32
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
#if !defined(_WIN32)
struct RWLock {
RWLock() { pthread_rwlock_init(&lock_, nullptr); }
~RWLock() { pthread_rwlock_destroy(&lock_); }
inline void RDLock() {
PADDLE_ENFORCE_EQ(
pthread_rwlock_rdlock(&lock_), 0,
platform::errors::External("The pthread failed to acquire read lock."));
}
inline void WRLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
platform::errors::External(
"The pthread failed to acquire write lock."));
}
inline void UNLock() {
PADDLE_ENFORCE_EQ(
pthread_rwlock_unlock(&lock_), 0,
platform::errors::External("The pthread failed to unlock."));
}
private:
pthread_rwlock_t lock_;
};
// TODO(paddle-dev): Support RWLock for WIN32 for correctness.
#else
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// On Windows there is no pthread rwlock, so fall back to a plain mutex:
// readers and writers all take the same exclusive lock.
struct RWLock {
// FIXME(minqiyang): use mutex here to do fake lock
inline void RDLock() { mutex_.lock(); }
inline void WRLock() { mutex_.lock(); }
inline void UNLock() { mutex_.unlock(); }
private:
std::mutex mutex_;
};
#endif
class AutoWRLock {
public:
explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
~AutoWRLock() { UnLock(); }
private:
inline void Lock() { lock_->WRLock(); }
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
};
class AutoRDLock {
public:
explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); }
~AutoRDLock() { UnLock(); }
private:
inline void Lock() { lock_->RDLock(); }
inline void UnLock() { lock_->UNLock(); }
private:
RWLock* lock_;
};
} // namespace framework
} // namespace paddle
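The deleted fluid header above is replaced by paddle/pten/core/utils/rw_lock.h with the same interface, as the pten::RWLock/AutoRDLock/AutoWRLock call sites elsewhere in this commit show. A minimal usage sketch against the pten header (the global counter and the thread bodies are illustrative):

#include <thread>
#include <vector>

#include "paddle/pten/core/utils/rw_lock.h"

int counter = 0;
pten::RWLock lock;  // assumed to keep the interface of the fluid RWLock above

void Reader() {
  pten::AutoRDLock guard{&lock};  // many readers may hold the lock at once
  int snapshot = counter;         // safe: writers are excluded while we read
  (void)snapshot;
}

void Writer() {
  pten::AutoWRLock guard{&lock};  // exclusive: blocks readers and other writers
  ++counter;
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) threads.emplace_back(Reader);
  threads.emplace_back(Writer);
  for (auto& t : threads) t.join();
  return counter == 1 ? 0 : 1;
}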
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/rw_lock.h"
#include <gtest/gtest.h>
#include <thread> // NOLINT
namespace f = paddle::framework;
void f1(f::RWLock *lock) {
lock->RDLock();
lock->UNLock();
}
TEST(RWLOCK, read_read) {
f::RWLock lock;
lock.RDLock();
std::thread t1(f1, &lock);
std::thread t2(f1, &lock);
t1.join();
t2.join();
lock.UNLock();
}
void f2(f::RWLock *lock, std::vector<int> *result) {
lock->RDLock();
ASSERT_EQ(result->size(), 0UL);
lock->UNLock();
}
void f3(f::RWLock *lock, std::vector<int> *result) {
lock->WRLock();
result->push_back(1);
lock->UNLock();
}
TEST(RWLOCK, read_write) {
f::RWLock lock;
std::vector<int> result;
lock.RDLock();
std::thread t1(f2, &lock, &result);
t1.join();
std::thread t2(f3, &lock, &result);
std::this_thread::sleep_for(std::chrono::seconds(1));
ASSERT_EQ(result.size(), 0UL);
lock.UNLock();
t2.join();
ASSERT_EQ(result.size(), 1UL);
}
void f4(f::RWLock *lock, std::vector<int> *result) {
lock->RDLock();
ASSERT_EQ(result->size(), 1UL);
lock->UNLock();
}
TEST(RWLOCK, write_read) {
f::RWLock lock;
std::vector<int> result;
lock.WRLock();
std::thread t1(f4, &lock, &result);
std::this_thread::sleep_for(std::chrono::seconds(1));
result.push_back(1);
lock.UNLock();
t1.join();
}
...@@ -34,10 +34,10 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -34,10 +34,10 @@ PADDLE_DEFINE_EXPORTED_bool(
#define SCOPE_VARS_READER_LOCK #define SCOPE_VARS_READER_LOCK
#define SCOPE_VARS_WRITER_LOCK #define SCOPE_VARS_WRITER_LOCK
#else #else
#define SCOPE_KIDS_READER_LOCK AutoRDLock auto_lock(&kids_lock_); #define SCOPE_KIDS_READER_LOCK pten::AutoRDLock auto_lock(&kids_lock_);
#define SCOPE_KIDS_WRITER_LOCK AutoWRLock auto_lock(&kids_lock_); #define SCOPE_KIDS_WRITER_LOCK pten::AutoWRLock auto_lock(&kids_lock_);
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_); #define SCOPE_VARS_READER_LOCK pten::AutoRDLock auto_lock(&vars_lock_);
#define SCOPE_VARS_WRITER_LOCK AutoWRLock auto_lock(&vars_lock_); #define SCOPE_VARS_WRITER_LOCK pten::AutoWRLock auto_lock(&vars_lock_);
#endif #endif
namespace paddle { namespace paddle {
......
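After this change the scope macros expand to pten guard objects, so each Scope method holds the appropriate lock for its whole body. A self-contained sketch of the same guard-macro pattern (TinyScope and its members are illustrative, not the real Scope):

#include <string>
#include <unordered_map>

#include "paddle/pten/core/utils/rw_lock.h"

#define VARS_READER_LOCK pten::AutoRDLock auto_lock(&vars_lock_);
#define VARS_WRITER_LOCK pten::AutoWRLock auto_lock(&vars_lock_);

class TinyScope {
 public:
  bool Has(const std::string& name) const {
    VARS_READER_LOCK  // shared lock: lookups may run concurrently
    return vars_.count(name) != 0;
  }
  void Set(const std::string& name, int v) {
    VARS_WRITER_LOCK  // exclusive lock: mutation excludes all readers
    vars_[name] = v;
  }

 private:
  std::unordered_map<std::string, int> vars_;
  mutable pten::RWLock vars_lock_;  // mutable so const readers can take the lock
};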
...@@ -26,9 +26,9 @@ extern "C" { ...@@ -26,9 +26,9 @@ extern "C" {
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/pten/core/utils/rw_lock.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -194,8 +194,8 @@ class Scope : public ScopeBase { ...@@ -194,8 +194,8 @@ class Scope : public ScopeBase {
#ifndef PADDLE_ON_INFERENCE #ifndef PADDLE_ON_INFERENCE
private: private:
mutable RWLock kids_lock_; mutable pten::RWLock kids_lock_;
mutable RWLock vars_lock_; mutable pten::RWLock vars_lock_;
#endif #endif
}; };
......
...@@ -17,73 +17,8 @@ limitations under the License. */ ...@@ -17,73 +17,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct ReAllocateVisitor { void SerializeToStream(std::ostream& os,
ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor) const pten::SelectedRows& selected_rows,
: dims_(dims), tensor_(tensor) {}
template <typename T>
void operator()() const {
framework::Tensor cpu_tensor;
platform::CPUPlace cpu;
T* ptr = cpu_tensor.mutable_data<T>(dims_, cpu);
const T* old_ptr =
tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
if (old_ptr != nullptr) {
std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
}
tensor_->ShareDataWith(cpu_tensor);
}
framework::DDim dims_;
framework::Tensor* tensor_;
};
struct TensorCopyVisitor {
TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
const framework::Tensor src, int64_t src_offset,
int64_t size)
: dst_(dst),
dst_offset_(dst_offset),
src_(src),
src_offset_(src_offset),
size_(size) {}
template <typename T>
void apply() const {
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
src_.data<T>() + src_offset_, size_ * sizeof(T));
}
framework::Tensor* dst_;
int64_t dst_offset_;
framework::Tensor src_;
int64_t src_offset_;
int64_t size_;
};
struct TensorFillVisitor {
  TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
                    float value)
      : dst_(dst), dst_offset_(dst_offset), size_(size), value_(value) {}
  template <typename T>
  void apply() const {
    // TODO(qiao): support other place
    platform::CPUPlace cpu;
    auto* tensor_data = dst_->mutable_data<T>(cpu);
    auto* start = tensor_data + dst_offset_;
    auto* end = start + size_;
    // fill the span with the requested value (previously hard-coded to 0.0,
    // which silently ignored the `value` argument)
    std::fill(start, end, static_cast<T>(value_));
  }
  framework::Tensor* dst_;
  int64_t dst_offset_;
  int64_t size_;
  float value_;
};
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version { // the 1st field, uint32_t version
constexpr uint32_t version = 0; constexpr uint32_t version = 0;
...@@ -107,7 +42,8 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, ...@@ -107,7 +42,8 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
TensorToStream(os, selected_rows.value(), dev_ctx); TensorToStream(os, selected_rows.value(), dev_ctx);
} }
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows) { void SerializeToStream(std::ostream& os,
const pten::SelectedRows& selected_rows) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx; const platform::DeviceContext* dev_ctx;
auto place = selected_rows.place(); auto place = selected_rows.place();
...@@ -115,14 +51,15 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows) { ...@@ -115,14 +51,15 @@ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows) {
SerializeToStream(os, selected_rows, *dev_ctx); SerializeToStream(os, selected_rows, *dev_ctx);
} }
void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows) { void DeserializeFromStream(std::istream& os,
pten::SelectedRows* selected_rows) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx; const platform::DeviceContext* dev_ctx;
dev_ctx = pool.Get(platform::CPUPlace()); dev_ctx = pool.Get(platform::CPUPlace());
DeserializeFromStream(os, selected_rows, *dev_ctx); DeserializeFromStream(os, selected_rows, *dev_ctx);
} }
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, void DeserializeFromStream(std::istream& is, pten::SelectedRows* selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ {
// the 1st field, uint32_t version for SelectedRows // the 1st field, uint32_t version for SelectedRows
...@@ -151,109 +88,5 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, ...@@ -151,109 +88,5 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
// the 4th field, tensor which contains the data // the 4th field, tensor which contains the data
TensorFromStream(is, selected_rows->mutable_value(), dev_ctx); TensorFromStream(is, selected_rows->mutable_value(), dev_ctx);
} }
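The stream layout starts with a uint32_t version and ends with the value tensor; the rows and height fields sit in between (collapsed in this hunk). A hedged round-trip sketch using the two convenience overloads above (CPU-only; the include path is an assumption based on the header shown further down):

#include <sstream>

#include "paddle/fluid/framework/selected_rows.h"  // assumed path of the fluid header

// Round-trip a SelectedRows through an in-memory stream.
void RoundTrip(const pten::SelectedRows& src, pten::SelectedRows* dst) {
  std::ostringstream oss;
  paddle::framework::SerializeToStream(oss, src);  // picks a DeviceContext from the pool
  std::istringstream iss(oss.str());
  paddle::framework::DeserializeFromStream(iss, dst);  // restores rows, height and value
}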
bool SelectedRows::HasKey(int64_t key) const {
  return std::find(rows_.begin(), rows_.end(), key) != rows_.end();
}
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
bool is_test) {
if (is_test) {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
rwlock_->RDLock();
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
rwlock_->UNLock();
PADDLE_ENFORCE_EQ(
auto_grown, true,
platform::errors::NotFound("Input key(%lld) is not found.", key));
rwlock_->WRLock();
auto map_size = id_to_index_.size();
auto vector_size = rows_.size();
if (map_size != vector_size) {
rwlock_->UNLock();
PADDLE_THROW(platform::errors::InvalidArgument(
"Row map size(%zu) should be equal to rows size(%zu).", map_size,
vector_size));
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
int row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW(platform::errors::InvalidArgument(
"Selected rows is full, then length exceed the length of first "
"dimension (%d).",
row_num));
}
// insert the new key at the tail of rows_ and record its index in id_to_index_
rows_.push_back(key);
auto index = static_cast<int64_t>(rows_.size() - 1);
id_to_index_[key] = index;
rwlock_->UNLock();
return index;
} else {
auto index = write_iter->second;
rwlock_->UNLock();
return index;
}
} else {
auto index = iter->second;
rwlock_->UNLock();
return index;
}
}
void SelectedRows::SyncIndex() {
rwlock_->WRLock();
id_to_index_.clear();
for (size_t i = 0; i < rows_.size(); ++i) {
id_to_index_[rows_[i]] = i;
}
rwlock_->UNLock();
}
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown, bool is_test) {
PADDLE_ENFORCE_EQ(value->IsInitialized(), true,
platform::errors::InvalidArgument(
"The value tensor is not initialized."));
if (ids.numel() == 0) {
VLOG(3) << "keys is empty, please check data!";
} else {
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(
value_width, value->numel() / value->dims()[0],
platform::errors::InvalidArgument(
"Output tensor should have the same shape with table "
"except the first dimmension, excepted value width not counting "
"the first dimension is %d, actual value width is %d.",
value_width, value->numel() / value->dims()[0]));
for (int i = 0; i < ids.numel(); ++i) {
auto id = ids.data<int64_t>()[i];
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType(
value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
framework::VisitDataType(
value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
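Although the implementation now lives in pten, the sparse-table interface (HasKey/AutoGrownIndex/Get) is unchanged, as the unit tests below confirm. A minimal usage sketch (sizes illustrative; assumes pten::SelectedRows keeps the interface documented above):

#include "paddle/fluid/framework/tensor.h"
#include "paddle/pten/core/selected_rows.h"

void SparseTableExample() {
  pten::SelectedRows table;
  // Back the table with a 4 x 8 float value tensor on CPU.
  table.mutable_value()->Resize(paddle::framework::make_ddim({4, 8}));
  table.mutable_value()->mutable_data<float>(paddle::platform::CPUPlace());

  // Inserts key 42 and returns its row index (0 for the first insertion).
  int64_t idx = table.AutoGrownIndex(/*key=*/42, /*auto_grown=*/true);
  bool has = table.HasKey(42);  // true after the insertion above
  (void)idx;
  (void)has;
}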
...@@ -21,153 +21,28 @@ limitations under the License. */ ...@@ -21,153 +21,28 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/pten/core/selected_rows.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class SelectedRows {
/*
 * @brief We can use the SelectedRows structure to represent a sparse table.
 * A sparse table is a key-value structure: the key is an `int64_t`, and the
 * value is one row of a Tensor.
 * You can use the following interfaces to operate the sparse table, and you
 * can find some detail information in the comments of each interface:
 *
 * HasKey(key), whether the sparse table has the specified key.
 * Set(key, value), set a key-value pair into the sparse table.
 * Get(keys, value*), get value by given key list and apply it to the given
 * value pointer with the specified offset.
 *
 */
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new Tensor());
rwlock_.reset(new RWLock);
}
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
rwlock_.reset(new RWLock);
}
const platform::Place& place() const { return value_->place(); }
const Tensor& value() const { return *value_; }
Tensor* mutable_value() { return value_.get(); }
int64_t height() const { return height_; }
void set_height(int64_t height) { height_ = height; }
const Vector<int64_t>& rows() const { return rows_; }
Vector<int64_t>* mutable_rows() { return &rows_; }
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/*
* @brief Get the index of key in rows
*
* @return the position of the key in rows; throws NotFound if the key does not exist.
*/
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Input id (%lld) is not in current rows table.", key));
}
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
/*
* @brief whether has the specified key in the table.
*
* @return true if the key exists.
*/
bool HasKey(int64_t key) const;
/*
* @brief Get value by the key list.
* Note!!! this interface is only used when selected_rows is used as a
* parameter for the distributed lookup table.
*
* @return a list of pairs containing each non-existent key and its index in
* the value
*/
void Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown = false, bool is_test = false);
/*
* @brief Get the index of the key from the id_to_index_ map. If the key does
* not exist, add the key into id_to_index_.
*
* Note!!! this interface is only used when selected_rows is used as a
* parameter for the distributed lookup table.
*
* @return index of the key.
*/
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
/*
* @brief Get the index of the key from id_to_index_ map.
*/
inline int64_t GetIndexFromId(int64_t key) const {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
void SyncIndex();
/*
* @brief Get the complete dims: the value dims with height_ as the first dimension.
*/
DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_;
return make_ddim(dims);
}
private:
// Notice: rows can contain duplicates. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simply concatenated when added together; the duplicate
// rows are only handled once a SelectedRows is added to a Tensor.
Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when rows_ has duplicate member
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;  // height indicates the underlying tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr};
};
/* /*
* Serialize/Deserialize SelectedRows to std::ostream * Serialize/Deserialize SelectedRows to std::ostream
* You can pass ofstream or ostringstream to serialize to file * You can pass ofstream or ostringstream to serialize to file
* or to an in-memory string. GPU tensor will be copied to CPU. * or to an in-memory string. GPU tensor will be copied to CPU.
*/ */
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, void SerializeToStream(std::ostream& os,
const pten::SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, void DeserializeFromStream(std::istream& is, pten::SelectedRows* selected_rows,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows); void SerializeToStream(std::ostream& os,
const pten::SelectedRows& selected_rows);
void DeserializeFromStream(std::istream& os, SelectedRows* selected_rows); void DeserializeFromStream(std::istream& os, pten::SelectedRows* selected_rows);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -24,7 +24,7 @@ class SelectedRowsTester : public ::testing::Test { ...@@ -24,7 +24,7 @@ class SelectedRowsTester : public ::testing::Test {
std::vector<int64_t> rows{0, 4, 7}; std::vector<int64_t> rows{0, 4, 7};
int64_t height = 10; int64_t height = 10;
int64_t row_numel = 100; int64_t row_numel = 100;
selected_rows_.reset(new SelectedRows(rows, height)); selected_rows_.reset(new pten::SelectedRows(rows, height));
Tensor* value = selected_rows_->mutable_value(); Tensor* value = selected_rows_->mutable_value();
auto* data = value->mutable_data<float>( auto* data = value->mutable_data<float>(
...@@ -36,7 +36,7 @@ class SelectedRowsTester : public ::testing::Test { ...@@ -36,7 +36,7 @@ class SelectedRowsTester : public ::testing::Test {
protected: protected:
platform::CPUPlace place_; platform::CPUPlace place_;
std::unique_ptr<SelectedRows> selected_rows_{nullptr}; std::unique_ptr<pten::SelectedRows> selected_rows_{nullptr};
}; };
TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); } TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); }
...@@ -50,7 +50,7 @@ TEST_F(SelectedRowsTester, complete_dims) { ...@@ -50,7 +50,7 @@ TEST_F(SelectedRowsTester, complete_dims) {
} }
TEST_F(SelectedRowsTester, SerializeAndDeseralize) { TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
SelectedRows dst_tensor; pten::SelectedRows dst_tensor;
platform::CPUDeviceContext cpu_ctx(place_); platform::CPUDeviceContext cpu_ctx(place_);
std::ostringstream oss; std::ostringstream oss;
...@@ -71,7 +71,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ...@@ -71,7 +71,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
TEST(SelectedRows, SparseTable) { TEST(SelectedRows, SparseTable) {
platform::CPUPlace cpu; platform::CPUPlace cpu;
SelectedRows table; pten::SelectedRows table;
int64_t table_size = 100; int64_t table_size = 100;
int64_t embedding_width = 8; int64_t embedding_width = 8;
...@@ -124,7 +124,7 @@ TEST(SelectedRows, SparseTable) { ...@@ -124,7 +124,7 @@ TEST(SelectedRows, SparseTable) {
} }
} }
void f1(SelectedRows* table, int table_size) { void f1(pten::SelectedRows* table, int table_size) {
for (int i = 1000000; i > 0; --i) { for (int i = 1000000; i > 0; --i) {
auto id = i % table_size; auto id = i % table_size;
int64_t index1 = table->AutoGrownIndex(id, true); int64_t index1 = table->AutoGrownIndex(id, true);
...@@ -135,7 +135,7 @@ void f1(SelectedRows* table, int table_size) { ...@@ -135,7 +135,7 @@ void f1(SelectedRows* table, int table_size) {
} }
} }
void f2(SelectedRows* table, int table_size) { void f2(pten::SelectedRows* table, int table_size) {
for (int i = 0; i < 1000000; ++i) { for (int i = 0; i < 1000000; ++i) {
auto id = i % table_size; auto id = i % table_size;
int64_t index1 = table->AutoGrownIndex(id, true); int64_t index1 = table->AutoGrownIndex(id, true);
...@@ -146,7 +146,7 @@ void f2(SelectedRows* table, int table_size) { ...@@ -146,7 +146,7 @@ void f2(SelectedRows* table, int table_size) {
} }
} }
void f3(SelectedRows* table, int table_size) { void f3(pten::SelectedRows* table, int table_size) {
clock_t t1 = clock(); clock_t t1 = clock();
for (int i = 100000; i > 0; --i) { for (int i = 100000; i > 0; --i) {
auto id1 = table->AutoGrownIndex(i % table_size, true); auto id1 = table->AutoGrownIndex(i % table_size, true);
...@@ -157,7 +157,7 @@ void f3(SelectedRows* table, int table_size) { ...@@ -157,7 +157,7 @@ void f3(SelectedRows* table, int table_size) {
std::cout << "f3 run time:" << t2 - t1 << std::endl; std::cout << "f3 run time:" << t2 - t1 << std::endl;
} }
void f4(SelectedRows* table, int table_size) { void f4(pten::SelectedRows* table, int table_size) {
clock_t t1 = clock(); clock_t t1 = clock();
for (int i = 0; i < 100000; ++i) { for (int i = 0; i < 100000; ++i) {
auto id1 = table->AutoGrownIndex(i % table_size, true); auto id1 = table->AutoGrownIndex(i % table_size, true);
...@@ -170,7 +170,7 @@ void f4(SelectedRows* table, int table_size) { ...@@ -170,7 +170,7 @@ void f4(SelectedRows* table, int table_size) {
TEST(SelectedRows, MultiThreadAutoIndex) { TEST(SelectedRows, MultiThreadAutoIndex) {
platform::CPUPlace cpu; platform::CPUPlace cpu;
SelectedRows table; pten::SelectedRows table;
int64_t table_size = 100000; int64_t table_size = 100000;
int64_t embedding_width = 8; int64_t embedding_width = 8;
......
...@@ -57,7 +57,7 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) { ...@@ -57,7 +57,7 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
visitor(var.Get<LoDTensorArray>()); visitor(var.Get<LoDTensorArray>());
return; return;
case proto::VarType::SELECTED_ROWS: case proto::VarType::SELECTED_ROWS:
visitor(var.Get<SelectedRows>()); visitor(var.Get<pten::SelectedRows>());
return; return;
case proto::VarType::READER: case proto::VarType::READER:
visitor(var.Get<ReaderHolder>()); visitor(var.Get<ReaderHolder>());
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
class SelectedRows;
} // namespace pten } // namespace pten
// Users should add forward declarations here // Users should add forward declarations here
...@@ -76,7 +77,6 @@ class LoDRankTable; ...@@ -76,7 +77,6 @@ class LoDRankTable;
class ScopeBase; class ScopeBase;
class ReaderHolder; class ReaderHolder;
class Scope; class Scope;
class SelectedRows;
} // namespace framework } // namespace framework
namespace operators { namespace operators {
...@@ -166,7 +166,7 @@ struct VarTypeRegistryImpl { ...@@ -166,7 +166,7 @@ struct VarTypeRegistryImpl {
// Users should add other variable types below. // Users should add other variable types below.
// Paddle would generate unique Ids for each registered variable type. // Paddle would generate unique Ids for each registered variable type.
using VarTypeRegistry = detail::VarTypeRegistryImpl< using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, SelectedRows, std::vector<Scope *>, LoDRankTable, Strings, Tensor, pten::SelectedRows, std::vector<Scope *>, LoDRankTable, Strings,
LoDTensorArray, platform::PlaceList, ReaderHolder, String, Scope *, LoDTensorArray, platform::PlaceList, ReaderHolder, String, Scope *,
operators::reader::LoDTensorBlockingQueueHolder, FetchList, FeedList, operators::reader::LoDTensorBlockingQueueHolder, FetchList, FeedList,
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder, operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder,
...@@ -206,7 +206,7 @@ struct VarTypeTrait { ...@@ -206,7 +206,7 @@ struct VarTypeTrait {
// Users should set some of variable type ids to be what is defined in // Users should set some of variable type ids to be what is defined in
// framework.proto below // framework.proto below
REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR); REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR);
REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS); REG_PROTO_VAR_TYPE_TRAIT(pten::SelectedRows, proto::VarType::SELECTED_ROWS);
REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES); REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES);
REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE); REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY); REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
......
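Registering pten::SelectedRows in the registry above (and mapping it to proto::VarType::SELECTED_ROWS) is what lets the generic Variable accessors keep working unchanged. A hedged sketch (Demo is illustrative):

#include <cassert>

#include "paddle/fluid/framework/variable.h"
#include "paddle/pten/core/selected_rows.h"

void Demo(paddle::framework::Variable* var) {
  // Legal only because pten::SelectedRows is a registered variable type.
  auto* slr = var->GetMutable<pten::SelectedRows>();
  slr->set_height(10);
  assert(var->IsType<pten::SelectedRows>());
}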
...@@ -92,7 +92,7 @@ bool CheckVarId(int proto_id) { ...@@ -92,7 +92,7 @@ bool CheckVarId(int proto_id) {
TEST(var_type_traits, check_proto_type_id) { TEST(var_type_traits, check_proto_type_id) {
ASSERT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR)); ASSERT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR));
ASSERT_TRUE(CheckVarId<SelectedRows>(proto::VarType::SELECTED_ROWS)); ASSERT_TRUE(CheckVarId<pten::SelectedRows>(proto::VarType::SELECTED_ROWS));
ASSERT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES)); ASSERT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES));
ASSERT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE)); ASSERT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE));
ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY)); ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
......
...@@ -123,8 +123,8 @@ inline pten::TensorInplaceVersion* Variable::InplaceVersionCounter() { ...@@ -123,8 +123,8 @@ inline pten::TensorInplaceVersion* Variable::InplaceVersionCounter() {
version_counter_ptr = version_counter_ptr =
&GetMutable<framework::Tensor>()->InplaceVersionCounter(); &GetMutable<framework::Tensor>()->InplaceVersionCounter();
} else if (IsType<framework::SelectedRows>()) { } else if (IsType<pten::SelectedRows>()) {
version_counter_ptr = &GetMutable<framework::SelectedRows>() version_counter_ptr = &GetMutable<pten::SelectedRows>()
->mutable_value() ->mutable_value()
->InplaceVersionCounter(); ->InplaceVersionCounter();
} else { } else {
......
...@@ -31,7 +31,7 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) { ...@@ -31,7 +31,7 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) { } else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>(); var->GetMutable<pten::SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) { } else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedList>(); var->GetMutable<FeedList>();
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
...@@ -70,9 +70,9 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) { ...@@ -70,9 +70,9 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
auto &src_tensor = src_var.Get<framework::LoDTensor>(); auto &src_tensor = src_var.Get<framework::LoDTensor>();
tmp_grad_tensor->set_lod(src_tensor.lod()); tmp_grad_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor); framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
} else if (src_var.IsType<framework::SelectedRows>()) { } else if (src_var.IsType<pten::SelectedRows>()) {
auto &src_slr = src_var.Get<framework::SelectedRows>(); auto &src_slr = src_var.Get<pten::SelectedRows>();
auto *tmp_grad_slr = dst_var->GetMutable<framework::SelectedRows>(); auto *tmp_grad_slr = dst_var->GetMutable<pten::SelectedRows>();
tmp_grad_slr->set_rows(src_slr.rows()); tmp_grad_slr->set_rows(src_slr.rows());
tmp_grad_slr->set_height(src_slr.height()); tmp_grad_slr->set_height(src_slr.height());
auto &src_t = src_slr.value(); auto &src_t = src_slr.value();
......
...@@ -39,8 +39,8 @@ static const platform::Place &GetVarPlace(const framework::Variable &src) { ...@@ -39,8 +39,8 @@ static const platform::Place &GetVarPlace(const framework::Variable &src) {
if (src.IsType<framework::LoDTensor>()) { if (src.IsType<framework::LoDTensor>()) {
return src.Get<framework::LoDTensor>().place(); return src.Get<framework::LoDTensor>().place();
#if NCCL_VERSION_CODE >= 2212 #if NCCL_VERSION_CODE >= 2212
} else if (src.IsType<framework::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
return src.Get<framework::SelectedRows>().value().place(); return src.Get<pten::SelectedRows>().value().place();
#endif #endif
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -70,8 +70,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, ...@@ -70,8 +70,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
} }
#if NCCL_VERSION_CODE >= 2212 #if NCCL_VERSION_CODE >= 2212
static void AllReduce(const framework::SelectedRows &src, static void AllReduce(const pten::SelectedRows &src, pten::SelectedRows *dst,
framework::SelectedRows *dst,
const ParallelStrategy &strategy, const ParallelStrategy &strategy,
const gpuStream_t stream, const gpuStream_t stream,
const platform::NCCLComm *comm) { const platform::NCCLComm *comm) {
...@@ -191,19 +190,18 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst, ...@@ -191,19 +190,18 @@ void AllReduce(const framework::Variable &src, framework::Variable *dst,
AllReduce(src.Get<framework::LoDTensor>(), AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>(), stream, comm); dst->GetMutable<framework::LoDTensor>(), stream, comm);
#if NCCL_VERSION_CODE >= 2212 #if NCCL_VERSION_CODE >= 2212
} else if (src.IsType<framework::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
if (&src != dst) { if (&src != dst) {
if (!dst->IsType<framework::SelectedRows>()) { if (!dst->IsType<pten::SelectedRows>()) {
dst->Clear(); dst->Clear();
} }
AllReduce(src.Get<framework::SelectedRows>(), AllReduce(src.Get<pten::SelectedRows>(),
dst->GetMutable<framework::SelectedRows>(), strategy, stream, dst->GetMutable<pten::SelectedRows>(), strategy, stream, comm);
comm);
} else { } else {
// SelectedRows cannot be allreduced in place // SelectedRows cannot be allreduced in place
framework::Variable tmp_dst; framework::Variable tmp_dst;
AllReduce(src.Get<framework::SelectedRows>(), AllReduce(src.Get<pten::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>(), strategy, stream, tmp_dst.GetMutable<pten::SelectedRows>(), strategy, stream,
comm); comm);
// stream must synchronize to ensure accuracy of the move operation // stream must synchronize to ensure accuracy of the move operation
platform::GpuStreamSync(stream); platform::GpuStreamSync(stream);
......
...@@ -365,12 +365,12 @@ class TracedGradOp { ...@@ -365,12 +365,12 @@ class TracedGradOp {
var_wrapper->MutableVar()->CurrentInplaceVersion()) { var_wrapper->MutableVar()->CurrentInplaceVersion()) {
return var_wrapper; return var_wrapper;
} else if (var_wrapper->MutableVar()->IsType<framework::LoDTensor>() || } else if (var_wrapper->MutableVar()->IsType<framework::LoDTensor>() ||
var_wrapper->MutableVar()->IsType<framework::SelectedRows>()) { var_wrapper->MutableVar()->IsType<pten::SelectedRows>()) {
auto* tensor = auto* tensor =
var_wrapper->MutableVar()->IsType<framework::LoDTensor>() var_wrapper->MutableVar()->IsType<framework::LoDTensor>()
? var_wrapper->MutableVar()->GetMutable<framework::LoDTensor>() ? var_wrapper->MutableVar()->GetMutable<framework::LoDTensor>()
: var_wrapper->MutableVar() : var_wrapper->MutableVar()
->GetMutable<framework::SelectedRows>() ->GetMutable<pten::SelectedRows>()
->mutable_value(); ->mutable_value();
if (!tensor->IsInitialized()) { if (!tensor->IsInitialized()) {
return var_wrapper; return var_wrapper;
......
...@@ -72,18 +72,18 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src, ...@@ -72,18 +72,18 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
} }
AllReduce(src.Get<framework::LoDTensor>(), AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>()); dst->GetMutable<framework::LoDTensor>());
} else if (src.IsType<framework::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
if (&src != dst) { if (&src != dst) {
if (!dst->IsType<framework::SelectedRows>()) { if (!dst->IsType<pten::SelectedRows>()) {
dst->Clear(); dst->Clear();
} }
AllReduce(src.Get<framework::SelectedRows>(), AllReduce(src.Get<pten::SelectedRows>(),
dst->GetMutable<framework::SelectedRows>()); dst->GetMutable<pten::SelectedRows>());
} else { } else {
// SelectedRows cannot be allreduced in place // SelectedRows cannot be allreduced in place
framework::Variable tmp_dst; framework::Variable tmp_dst;
AllReduce(src.Get<framework::SelectedRows>(), AllReduce(src.Get<pten::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>()); tmp_dst.GetMutable<pten::SelectedRows>());
*dst = std::move(tmp_dst); *dst = std::move(tmp_dst);
} }
} else { } else {
...@@ -120,8 +120,8 @@ void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor, ...@@ -120,8 +120,8 @@ void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor,
break; \ break; \
} }
void GLOOParallelContext::AllReduce(const framework::SelectedRows &src, void GLOOParallelContext::AllReduce(const pten::SelectedRows &src,
framework::SelectedRows *dst) { pten::SelectedRows *dst) {
// auto ; // auto ;
// int local_rank = strategy_.local_rank_; // int local_rank = strategy_.local_rank_;
int nranks = strategy_.nranks_; int nranks = strategy_.nranks_;
......
...@@ -59,8 +59,7 @@ class GLOOParallelContext : public ParallelContext { ...@@ -59,8 +59,7 @@ class GLOOParallelContext : public ParallelContext {
private: private:
void AllReduce(const framework::Tensor& src, framework::Tensor* dst); void AllReduce(const framework::Tensor& src, framework::Tensor* dst);
void AllReduce(const framework::SelectedRows& src, void AllReduce(const pten::SelectedRows& src, pten::SelectedRows* dst);
framework::SelectedRows* dst);
private: private:
std::unique_ptr<platform::CPUDeviceContext> device_; std::unique_ptr<platform::CPUDeviceContext> device_;
......
...@@ -55,12 +55,12 @@ static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src, ...@@ -55,12 +55,12 @@ static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src,
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>(); auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor); framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
dst_tensor->set_lod(src_tensor.lod()); dst_tensor->set_lod(src_tensor.lod());
} else if (src->IsType<framework::SelectedRows>()) { } else if (src->IsType<pten::SelectedRows>()) {
auto& src_selected_rows = src->Get<framework::SelectedRows>(); auto& src_selected_rows = src->Get<pten::SelectedRows>();
if (!dst->IsType<framework::SelectedRows>()) { if (!dst->IsType<pten::SelectedRows>()) {
dst->Clear(); dst->Clear();
} }
auto* dst_selected_rows = dst->GetMutable<framework::SelectedRows>(); auto* dst_selected_rows = dst->GetMutable<pten::SelectedRows>();
framework::TensorCopy(src_selected_rows.value(), framework::TensorCopy(src_selected_rows.value(),
src_selected_rows.value().place(), src_selected_rows.value().place(),
dst_selected_rows->mutable_value()); dst_selected_rows->mutable_value());
...@@ -332,7 +332,7 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { ...@@ -332,7 +332,7 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
void SelectedRowsAddToTensor(const framework::Variable& src, void SelectedRowsAddToTensor(const framework::Variable& src,
framework::Variable* dst) { framework::Variable* dst) {
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>(); auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
auto& src_selected_rows = src.Get<framework::SelectedRows>(); auto& src_selected_rows = src.Get<pten::SelectedRows>();
auto place = dst_tensor->place(); auto place = dst_tensor->place();
auto data_type = src_selected_rows.value().type(); auto data_type = src_selected_rows.value().type();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
...@@ -371,7 +371,7 @@ static void SelectedRowsAddTensor( ...@@ -371,7 +371,7 @@ static void SelectedRowsAddTensor(
const framework::Variable& src_tensor_var, const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var) { framework::Variable* dst_tensor_var) {
const auto& src_selected_rows = const auto& src_selected_rows =
src_selected_rows_var.Get<framework::SelectedRows>(); src_selected_rows_var.Get<pten::SelectedRows>();
const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>(); const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>();
const auto& place = src_tensor.place(); const auto& place = src_tensor.place();
auto data_type = src_tensor.type(); auto data_type = src_tensor.type();
...@@ -414,18 +414,18 @@ static void SelectedRowsAddTensor( ...@@ -414,18 +414,18 @@ static void SelectedRowsAddTensor(
// to one, then add it to an empty selected rows; the result is correct // to one, then add it to an empty selected rows; the result is correct
std::shared_ptr<VariableWrapper> SelectedRowsMerge( std::shared_ptr<VariableWrapper> SelectedRowsMerge(
const framework::Variable& src1, const framework::Variable& src2) { const framework::Variable& src1, const framework::Variable& src2) {
auto& src_selected_rows1 = src1.Get<framework::SelectedRows>(); auto& src_selected_rows1 = src1.Get<pten::SelectedRows>();
auto& src_selected_rows2 = src2.Get<framework::SelectedRows>(); auto& src_selected_rows2 = src2.Get<pten::SelectedRows>();
auto place = src_selected_rows1.value().place(); auto place = src_selected_rows1.value().place();
auto data_type = src_selected_rows1.value().type(); auto data_type = src_selected_rows1.value().type();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
std::vector<const framework::SelectedRows*> src_selected_rows; std::vector<const pten::SelectedRows*> src_selected_rows;
src_selected_rows.emplace_back(&src_selected_rows1); src_selected_rows.emplace_back(&src_selected_rows1);
src_selected_rows.emplace_back(&src_selected_rows2); src_selected_rows.emplace_back(&src_selected_rows2);
auto dst_var = std::make_shared<VariableWrapper>("Temp"); auto dst_var = std::make_shared<VariableWrapper>("Temp");
auto* dst_selected_rows = auto* dst_selected_rows =
dst_var->MutableVar()->GetMutable<framework::SelectedRows>(); dst_var->MutableVar()->GetMutable<pten::SelectedRows>();
#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type) \ #define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \ if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
...@@ -463,7 +463,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var, ...@@ -463,7 +463,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
if (dst->IsType<framework::LoDTensor>()) { if (dst->IsType<framework::LoDTensor>()) {
if (src.IsType<framework::LoDTensor>()) { if (src.IsType<framework::LoDTensor>()) {
TensorAdd(src, dst); TensorAdd(src, dst);
} else if (src.IsType<framework::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
SelectedRowsAddToTensor(src, dst); SelectedRowsAddToTensor(src, dst);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -481,7 +481,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var, ...@@ -481,7 +481,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
SelectedRowsAddToTensor(*dst, src_mutable); SelectedRowsAddToTensor(*dst, src_mutable);
*dst = std::move(*(var->MutableVar())); *dst = std::move(*(var->MutableVar()));
} }
} else if (src.IsType<framework::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
auto temp = SelectedRowsMerge(src, *dst); auto temp = SelectedRowsMerge(src, *dst);
*dst = std::move(*(temp->MutableVar())); *dst = std::move(*(temp->MutableVar()));
} else { } else {
...@@ -497,8 +497,8 @@ static platform::Place GetPlaceOfVar( ...@@ -497,8 +497,8 @@ static platform::Place GetPlaceOfVar(
platform::Place place; platform::Place place;
if (var->Var().IsType<framework::LoDTensor>()) { if (var->Var().IsType<framework::LoDTensor>()) {
place = var->Var().Get<framework::LoDTensor>().place(); place = var->Var().Get<framework::LoDTensor>().place();
} else if (var->Var().IsType<framework::SelectedRows>()) { } else if (var->Var().IsType<pten::SelectedRows>()) {
place = var->Var().Get<framework::SelectedRows>().place(); place = var->Var().Get<pten::SelectedRows>().place();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"only support LoDTensor and SelectedRows in dygraph")); "only support LoDTensor and SelectedRows in dygraph"));
...@@ -530,14 +530,14 @@ void GradientAccumulator::AccumulateGrad() { ...@@ -530,14 +530,14 @@ void GradientAccumulator::AccumulateGrad() {
if (dst->IsType<framework::LoDTensor>()) { if (dst->IsType<framework::LoDTensor>()) {
if (src->IsType<framework::LoDTensor>()) { if (src->IsType<framework::LoDTensor>()) {
TensorAdd(*src, dst); TensorAdd(*src, dst);
} else if (src->IsType<framework::SelectedRows>()) { } else if (src->IsType<pten::SelectedRows>()) {
SelectedRowsAddToTensor(*src, dst); SelectedRowsAddToTensor(*src, dst);
} }
} else if (dst->IsType<framework::SelectedRows>()) { } else if (dst->IsType<pten::SelectedRows>()) {
if (src->IsType<framework::LoDTensor>()) { if (src->IsType<framework::LoDTensor>()) {
SelectedRowsAddToTensor(*dst, src); SelectedRowsAddToTensor(*dst, src);
*dst = std::move(*src); *dst = std::move(*src);
} else if (src->IsType<framework::SelectedRows>()) { } else if (src->IsType<pten::SelectedRows>()) {
auto temp = SelectedRowsMerge(*src, *dst); auto temp = SelectedRowsMerge(*src, *dst);
*dst = std::move(*(temp->MutableVar())); *dst = std::move(*(temp->MutableVar()));
} }
...@@ -657,7 +657,7 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -657,7 +657,7 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
// so synchronize VariableWrapper with Variable. // so synchronize VariableWrapper with Variable.
if (dst_var->Var().IsType<framework::LoDTensor>()) { if (dst_var->Var().IsType<framework::LoDTensor>()) {
dst_var->SetType(framework::proto::VarType::LOD_TENSOR); dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (dst_var->Var().IsType<framework::SelectedRows>()) { } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
dst_var->SetType(framework::proto::VarType::SELECTED_ROWS); dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
} }
...@@ -701,7 +701,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -701,7 +701,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
if (paddle::platform::is_gpu_place(place)) { if (paddle::platform::is_gpu_place(place)) {
// sum selected rows firstly // sum selected rows firstly
for (auto& var_info : tmp_grad_vars_) { for (auto& var_info : tmp_grad_vars_) {
if (!var_info.var->Var().IsType<framework::SelectedRows>()) { if (!var_info.var->Var().IsType<pten::SelectedRows>()) {
continue; continue;
} }
...@@ -744,7 +744,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -744,7 +744,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_info.var->Var().IsType<framework::LoDTensor>() || var_info.var->Var().IsType<framework::LoDTensor>() ||
var_info.var->Var().IsType<framework::SelectedRows>(), var_info.var->Var().IsType<pten::SelectedRows>(),
true, platform::errors::PermissionDenied("The type of Gradient " true, platform::errors::PermissionDenied("The type of Gradient "
"var must be LoDTensor " "var must be LoDTensor "
"or SelectedRows")); "or SelectedRows"));
...@@ -789,7 +789,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var, ...@@ -789,7 +789,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
if (dst_var->Var().IsType<framework::LoDTensor>()) { if (dst_var->Var().IsType<framework::LoDTensor>()) {
dst_var->SetType(framework::proto::VarType::LOD_TENSOR); dst_var->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (dst_var->Var().IsType<framework::SelectedRows>()) { } else if (dst_var->Var().IsType<pten::SelectedRows>()) {
dst_var->SetType(framework::proto::VarType::SELECTED_ROWS); dst_var->SetType(framework::proto::VarType::SELECTED_ROWS);
} }
} }
......
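The accumulation dispatch above implements three rules: dense += dense via TensorAdd; dense += sparse via SelectedRowsAddToTensor (a scatter-add); and sparse += sparse via SelectedRowsMerge, which sums duplicate row ids. A minimal self-contained sketch of those row semantics follows; the Rows struct and helper names are invented stand-ins for pten::SelectedRows, not Paddle APIs.

// Illustrative only: width-1 rows for brevity.
#include <cstdio>
#include <map>
#include <vector>

struct Rows {                      // toy stand-in for pten::SelectedRows
  std::vector<int> idx;            // which rows of the full tensor are present
  std::vector<double> val;         // one value per listed row
};

// dst is dense: scatter-add every sparse row into it (SelectedRowsAddToTensor).
void AddToDense(const Rows& src, std::vector<double>* dst) {
  for (size_t i = 0; i < src.idx.size(); ++i) (*dst)[src.idx[i]] += src.val[i];
}

// Both sparse: merge duplicate row ids and sum their values (SelectedRowsMerge).
Rows Merge(const Rows& a, const Rows& b) {
  std::map<int, double> acc;
  for (size_t i = 0; i < a.idx.size(); ++i) acc[a.idx[i]] += a.val[i];
  for (size_t i = 0; i < b.idx.size(); ++i) acc[b.idx[i]] += b.val[i];
  Rows out;
  for (auto& kv : acc) { out.idx.push_back(kv.first); out.val.push_back(kv.second); }
  return out;
}

int main() {
  std::vector<double> dense(4, 0.0);
  Rows g1{{0, 2}, {1.0, 2.0}}, g2{{2, 3}, {3.0, 4.0}};
  AddToDense(Merge(g1, g2), &dense);  // dense = {1, 0, 5, 4}
  for (double v : dense) std::printf("%g ", v);
  return 0;
}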
...@@ -31,7 +31,7 @@ class GradientAccumulator { ...@@ -31,7 +31,7 @@ class GradientAccumulator {
if (var && var->Var().IsInitialized()) { if (var && var->Var().IsInitialized()) {
if (var->Var().IsType<framework::LoDTensor>()) { if (var->Var().IsType<framework::LoDTensor>()) {
var->SetType(framework::proto::VarType::LOD_TENSOR); var->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (var->Var().IsType<framework::SelectedRows>()) { } else if (var->Var().IsType<pten::SelectedRows>()) {
var->SetType(framework::proto::VarType::SELECTED_ROWS); var->SetType(framework::proto::VarType::SELECTED_ROWS);
} else { } else {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
......
...@@ -196,8 +196,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -196,8 +196,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>(); auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims()); out_lod_tensor->Resize(in_lod_tensor.dims());
} else { } else {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>(); auto& in_sele_rows = in_var->Get<pten::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>(); auto out_sele_rows = out_var->GetMutable<pten::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows()); out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height()); out_sele_rows->set_height(in_sele_rows.height());
...@@ -365,8 +365,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -365,8 +365,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
"Input variable should not be null")); "Input variable should not be null"));
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().dims(); return var->Get<framework::LoDTensor>().dims();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
return var->Get<framework::SelectedRows>().GetCompleteDims(); return var->Get<pten::SelectedRows>().GetCompleteDims();
} else { } else {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Only LoDTensor/SelectedRows support 'GetDim', but Variables " "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
...@@ -382,8 +382,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -382,8 +382,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
void SetDim(framework::Variable* var, const DDim& dim) { void SetDim(framework::Variable* var, const DDim& dim) {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
var->GetMutable<framework::LoDTensor>()->Resize(dim); var->GetMutable<framework::LoDTensor>()->Resize(dim);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->set_height(dim[0]); var->GetMutable<pten::SelectedRows>()->set_height(dim[0]);
} else { } else {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Variable type_id %s, expect LoDTensor/SelectedRows.")); "Variable type_id %s, expect LoDTensor/SelectedRows."));
......
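The GetDim/SetDim pair above relies on the SelectedRows shape convention: GetCompleteDims() reports the logical full shape, i.e. the value tensor's dims with dim 0 replaced by height(), so SetDim only needs to reset the height. A toy sketch of that convention, assuming this is what GetCompleteDims computes; the types here are illustrative, not Paddle's.

#include <cassert>
#include <cstdint>
#include <vector>

struct ToySelectedRows {
  int64_t height = 0;               // logical number of rows
  std::vector<int64_t> value_dims;  // shape of the stored (compact) rows
  std::vector<int64_t> GetCompleteDims() const {
    std::vector<int64_t> d = value_dims;
    d[0] = height;                  // only dim 0 differs from value()
    return d;
  }
};

int main() {
  ToySelectedRows sr;
  sr.height = 100;                  // full table has 100 rows...
  sr.value_dims = {3, 16};          // ...but only 3 rows are materialized
  assert((sr.GetCompleteDims() == std::vector<int64_t>{100, 16}));
  return 0;
}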
...@@ -105,9 +105,9 @@ static std::string DebugString( ...@@ -105,9 +105,9 @@ static std::string DebugString(
ss << "NOT_INITED"; ss << "NOT_INITED";
} }
ss << ">"; ss << ">";
} else if (var.IsType<framework::SelectedRows>()) { } else if (var.IsType<pten::SelectedRows>()) {
ss << "SelectedRows<"; ss << "SelectedRows<";
auto& selected_rows = var.Get<framework::SelectedRows>(); auto& selected_rows = var.Get<pten::SelectedRows>();
auto& tensor = selected_rows.value(); auto& tensor = selected_rows.value();
auto& rows = selected_rows.rows(); auto& rows = selected_rows.rows();
if (tensor.IsInitialized()) { if (tensor.IsInitialized()) {
...@@ -188,9 +188,8 @@ size_t VarBase::GradOpNum() const { ...@@ -188,9 +188,8 @@ size_t VarBase::GradOpNum() const {
void VarBase::ClearGradient(bool set_to_zero) { void VarBase::ClearGradient(bool set_to_zero) {
VLOG(4) << "ClearGradient " << Name(); VLOG(4) << "ClearGradient " << Name();
if (grad_var_) { if (grad_var_) {
if (grad_var_->Var().IsType<framework::SelectedRows>()) { if (grad_var_->Var().IsType<pten::SelectedRows>()) {
auto* grad_t = auto* grad_t = grad_var_->MutableVar()->GetMutable<pten::SelectedRows>();
grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
if (grad_t->mutable_value()->IsInitialized()) { if (grad_t->mutable_value()->IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) platform::ClearMKLDNNCache(grad_t->place()); if (FLAGS_use_mkldnn) platform::ClearMKLDNNCache(grad_t->place());
...@@ -248,7 +247,7 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -248,7 +247,7 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
const bool blocking) const { const bool blocking) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() || Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() ||
Var().IsType<framework::SelectedRows>()), Var().IsType<pten::SelectedRows>()),
true, platform::errors::InvalidArgument( true, platform::errors::InvalidArgument(
"Variable is not initialized or Variable's type is not " "Variable is not initialized or Variable's type is not "
"LoDTensor or SelectedRows when getting numpy tensor")); "LoDTensor or SelectedRows when getting numpy tensor"));
...@@ -277,12 +276,12 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place, ...@@ -277,12 +276,12 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
<< dst_place; << dst_place;
return new_var; return new_var;
} else { } else {
auto& src_selected_rows = Var().Get<framework::SelectedRows>(); auto& src_selected_rows = Var().Get<pten::SelectedRows>();
auto new_var = std::make_shared<VarBase>( auto new_var = std::make_shared<VarBase>(
false, "Itmp" + std::to_string(copied_counter_++)); false, "Itmp" + std::to_string(copied_counter_++));
new_var->SetType(framework::proto::VarType::SELECTED_ROWS); new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
auto* dst_selected_rows = auto* dst_selected_rows =
new_var->MutableVar()->GetMutable<framework::SelectedRows>(); new_var->MutableVar()->GetMutable<pten::SelectedRows>();
framework::TensorCopy(src_selected_rows.value(), dst_place, framework::TensorCopy(src_selected_rows.value(), dst_place,
dst_selected_rows->mutable_value()); dst_selected_rows->mutable_value());
...@@ -346,10 +345,9 @@ void VarBase::CopyFrom(const VarBase& src, const bool blocking) { ...@@ -346,10 +345,9 @@ void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
dst_tensor->Resize(src_tensor.dims()); dst_tensor->Resize(src_tensor.dims());
} }
framework::TensorCopy(src_tensor, place, dst_tensor); framework::TensorCopy(src_tensor, place, dst_tensor);
} else if (src.Var().IsType<framework::SelectedRows>()) { } else if (src.Var().IsType<pten::SelectedRows>()) {
auto& src_selected_rows = src.Var().Get<framework::SelectedRows>(); auto& src_selected_rows = src.Var().Get<pten::SelectedRows>();
auto* dst_selected_rows = auto* dst_selected_rows = MutableVar()->GetMutable<pten::SelectedRows>();
MutableVar()->GetMutable<framework::SelectedRows>();
dst_selected_rows->set_height(src_selected_rows.height()); dst_selected_rows->set_height(src_selected_rows.height());
dst_selected_rows->set_rows(src_selected_rows.rows()); dst_selected_rows->set_rows(src_selected_rows.rows());
......
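Both copy paths above (NewVarBase and CopyFrom) move the same three fields, since a SelectedRows is fully described by (height, rows, value). A toy deep-copy sketch; only the field names mirror pten::SelectedRows.

#include <cassert>
#include <cstdint>
#include <vector>

struct ToySR {
  int64_t height = 0;
  std::vector<int64_t> rows;
  std::vector<float> value;  // stand-in for the value tensor
};

ToySR DeepCopy(const ToySR& src) {
  ToySR dst;
  dst.height = src.height;   // set_height(src.height())
  dst.rows = src.rows;       // set_rows(src.rows())
  dst.value = src.value;     // TensorCopy(src.value(), place, dst.mutable_value())
  return dst;
}

int main() {
  ToySR a{10, {1, 7}, {0.5f, -2.0f}};
  ToySR b = DeepCopy(a);
  assert(b.height == 10 && b.rows == a.rows && b.value == a.value);
  return 0;
}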
...@@ -47,8 +47,8 @@ const std::shared_ptr<VariableWrapper>& GetVariableWrapper( ...@@ -47,8 +47,8 @@ const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
if (var.IsType<framework::LoDTensor>()) { if (var.IsType<framework::LoDTensor>()) {
return &(var.Get<framework::LoDTensor>()); return &(var.Get<framework::LoDTensor>());
} else if (var.IsType<framework::SelectedRows>()) { } else if (var.IsType<pten::SelectedRows>()) {
return &(var.Get<framework::SelectedRows>().value()); return &(var.Get<pten::SelectedRows>().value());
} else { } else {
return nullptr; return nullptr;
} }
......
...@@ -36,8 +36,7 @@ namespace imperative { ...@@ -36,8 +36,7 @@ namespace imperative {
void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) { void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
framework::Tensor *tensor = framework::Tensor *tensor =
is_sparse_ is_sparse_
? sparse_contents_->GetMutable<framework::SelectedRows>() ? sparse_contents_->GetMutable<pten::SelectedRows>()->mutable_value()
->mutable_value()
: dense_contents_.GetMutable<framework::LoDTensor>(); : dense_contents_.GetMutable<framework::LoDTensor>();
if (platform::is_gpu_place(tensor->place())) { if (platform::is_gpu_place(tensor->place())) {
...@@ -775,7 +774,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { ...@@ -775,7 +774,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
auto var_base = vars_[var_index]->GradVarBase(); auto var_base = vars_[var_index]->GradVarBase();
// need to check tensor type // need to check tensor type
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_base->Var().IsType<framework::SelectedRows>(), true, var_base->Var().IsType<pten::SelectedRows>(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The sparse parameter[%d][%s] must have a selectedrows gradient. " "The sparse parameter[%d][%s] must have a selectedrows gradient. "
"Before forward pass, the parameter type is inferred to be " "Before forward pass, the parameter type is inferred to be "
...@@ -995,8 +994,8 @@ bool Reducer::HasGrad(size_t var_index) { ...@@ -995,8 +994,8 @@ bool Reducer::HasGrad(size_t var_index) {
if (var.Get<framework::LoDTensor>().IsInitialized()) { if (var.Get<framework::LoDTensor>().IsInitialized()) {
return true; return true;
} }
} else if (var.IsType<framework::SelectedRows>()) { } else if (var.IsType<pten::SelectedRows>()) {
if (var.Get<framework::SelectedRows>().value().IsInitialized()) { if (var.Get<pten::SelectedRows>().value().IsInitialized()) {
return true; return true;
} }
} else { } else {
......
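Group::DivNRanks above picks a single payload tensor — the SelectedRows' value() when the group is sparse, the dense buffer otherwise — and scales it by 1/nranks. A self-contained sketch of that reduction step on plain buffers; Paddle performs it through a DeviceContext instead.

#include <cassert>
#include <cstdint>
#include <vector>

void DivNRanks(std::vector<float>* payload, int64_t nranks) {
  for (float& v : *payload) v /= static_cast<float>(nranks);
}

int main() {
  std::vector<float> dense{2.0f, 4.0f};   // dense_contents_ case
  std::vector<float> sparse_value{8.0f};  // sparse value() case
  DivNRanks(&dense, 2);
  DivNRanks(&sparse_value, 2);
  assert(dense[0] == 1.0f && dense[1] == 2.0f && sparse_value[0] == 4.0f);
  return 0;
}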
...@@ -124,8 +124,8 @@ static void CopyVar(const framework::Variable& var, ...@@ -124,8 +124,8 @@ static void CopyVar(const framework::Variable& var,
auto* dst_tensor = dst.GetMutable<framework::LoDTensor>(); auto* dst_tensor = dst.GetMutable<framework::LoDTensor>();
framework::TensorCopySync(src_tensor, src_tensor.place(), dst_tensor); framework::TensorCopySync(src_tensor, src_tensor.place(), dst_tensor);
} else { } else {
const auto& src_selected_rows = var.Get<framework::SelectedRows>(); const auto& src_selected_rows = var.Get<pten::SelectedRows>();
auto* dst_selected_rows = dst.GetMutable<framework::SelectedRows>(); auto* dst_selected_rows = dst.GetMutable<pten::SelectedRows>();
dst_selected_rows->set_rows(src_selected_rows.rows()); dst_selected_rows->set_rows(src_selected_rows.rows());
dst_selected_rows->set_height(src_selected_rows.height()); dst_selected_rows->set_height(src_selected_rows.height());
framework::TensorCopySync(src_selected_rows.value(), framework::TensorCopySync(src_selected_rows.value(),
...@@ -148,8 +148,8 @@ static bool IsEqualVar(const framework::Variable& var1, ...@@ -148,8 +148,8 @@ static bool IsEqualVar(const framework::Variable& var1,
framework::TensorCopySync(var2.Get<framework::LoDTensor>(), framework::TensorCopySync(var2.Get<framework::LoDTensor>(),
platform::CPUPlace(), &t2); platform::CPUPlace(), &t2);
} else { } else {
auto& s1 = var1.Get<framework::SelectedRows>(); auto& s1 = var1.Get<pten::SelectedRows>();
auto& s2 = var2.Get<framework::SelectedRows>(); auto& s2 = var2.Get<pten::SelectedRows>();
if (s1.height() != s2.height()) { if (s1.height() != s2.height()) {
return false; return false;
...@@ -166,9 +166,9 @@ static bool IsEqualVar(const framework::Variable& var1, ...@@ -166,9 +166,9 @@ static bool IsEqualVar(const framework::Variable& var1,
return false; return false;
} }
framework::TensorCopySync(var1.Get<framework::SelectedRows>().value(), framework::TensorCopySync(var1.Get<pten::SelectedRows>().value(),
platform::CPUPlace(), &t1); platform::CPUPlace(), &t1);
framework::TensorCopySync(var2.Get<framework::SelectedRows>().value(), framework::TensorCopySync(var2.Get<pten::SelectedRows>().value(),
platform::CPUPlace(), &t2); platform::CPUPlace(), &t2);
} }
...@@ -211,7 +211,7 @@ static framework::Variable RandomSelectedRows(framework::DDim dims, ...@@ -211,7 +211,7 @@ static framework::Variable RandomSelectedRows(framework::DDim dims,
dims[0] = row_number; dims[0] = row_number;
framework::Variable ret; framework::Variable ret;
auto* sr = ret.GetMutable<framework::SelectedRows>(); auto* sr = ret.GetMutable<pten::SelectedRows>();
auto tensor_var = RandomTensor<T>(dims, place, low, high); auto tensor_var = RandomTensor<T>(dims, place, low, high);
sr->mutable_value()->ShareDataWith( sr->mutable_value()->ShareDataWith(
tensor_var.template Get<framework::LoDTensor>()); tensor_var.template Get<framework::LoDTensor>());
......
...@@ -237,7 +237,7 @@ TEST(test_layer, test_debug_string) { ...@@ -237,7 +237,7 @@ TEST(test_layer, test_debug_string) {
std::shared_ptr<imperative::VarBase> selected_rows( std::shared_ptr<imperative::VarBase> selected_rows(
new imperative::VarBase(false, "selected_rows")); new imperative::VarBase(false, "selected_rows"));
auto tensor_sr = selected_rows->MutableVar() auto tensor_sr = selected_rows->MutableVar()
->GetMutable<framework::SelectedRows>() ->GetMutable<pten::SelectedRows>()
->mutable_value(); ->mutable_value();
std::string res_ui_sr = test_func(selected_rows); std::string res_ui_sr = test_func(selected_rows);
ASSERT_TRUE(res_ui_sr.find("NOT_INITED") != std::string::npos); ASSERT_TRUE(res_ui_sr.find("NOT_INITED") != std::string::npos);
......
...@@ -101,7 +101,7 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var); ...@@ -101,7 +101,7 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
TEST(test_prepare_op, test_get_tensor_from_var) { TEST(test_prepare_op, test_get_tensor_from_var) {
std::shared_ptr<imperative::VarBase> vout_error( std::shared_ptr<imperative::VarBase> vout_error(
new imperative::VarBase(false, "vout_error")); new imperative::VarBase(false, "vout_error"));
vout_error->MutableVar()->GetMutable<framework::SelectedRows>(); vout_error->MutableVar()->GetMutable<pten::SelectedRows>();
auto* ts = GetTensorFromVar(*vout_error->MutableVar()); auto* ts = GetTensorFromVar(*vout_error->MutableVar());
ASSERT_TRUE(ts != nullptr); ASSERT_TRUE(ts != nullptr);
} }
......
...@@ -104,8 +104,8 @@ class VariableWrapper { ...@@ -104,8 +104,8 @@ class VariableWrapper {
const framework::Tensor* tensor = nullptr; const framework::Tensor* tensor = nullptr;
if (var_.IsType<framework::LoDTensor>()) { if (var_.IsType<framework::LoDTensor>()) {
tensor = &(var_.Get<framework::LoDTensor>()); tensor = &(var_.Get<framework::LoDTensor>());
} else if (var_.IsType<framework::SelectedRows>()) { } else if (var_.IsType<pten::SelectedRows>()) {
tensor = &(var_.Get<framework::SelectedRows>().value()); tensor = &(var_.Get<pten::SelectedRows>().value());
} else { } else {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Only support LoDTensor and SelectedRows for gradient var")); "Only support LoDTensor and SelectedRows for gradient var"));
...@@ -153,7 +153,7 @@ class VariableWrapper { ...@@ -153,7 +153,7 @@ class VariableWrapper {
if (type_ == framework::proto::VarType::LOD_TENSOR) { if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>()); tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) { } else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value()); tensor = &(var_.Get<pten::SelectedRows>().value());
} else if (type_ == framework::proto::VarType::VOCAB) { } else if (type_ == framework::proto::VarType::VOCAB) {
const framework::Vocab* data = nullptr; const framework::Vocab* data = nullptr;
data = &(var_.Get<framework::Vocab>()); data = &(var_.Get<framework::Vocab>());
...@@ -193,7 +193,7 @@ class VariableWrapper { ...@@ -193,7 +193,7 @@ class VariableWrapper {
if (type_ == framework::proto::VarType::LOD_TENSOR) { if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>()); tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) { } else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value()); tensor = &(var_.Get<pten::SelectedRows>().value());
} else { } else {
VLOG(6) << "Variable " << name_ << " is not initialized"; VLOG(6) << "Variable " << name_ << " is not initialized";
return place; return place;
......
...@@ -43,7 +43,7 @@ struct TensorArrayBatchCleaner { ...@@ -43,7 +43,7 @@ struct TensorArrayBatchCleaner {
constexpr auto kLoDTensorId = constexpr auto kLoDTensorId =
framework::VarTypeTrait<framework::LoDTensor>::kId; framework::VarTypeTrait<framework::LoDTensor>::kId;
constexpr auto kSelectedRowsId = constexpr auto kSelectedRowsId =
framework::VarTypeTrait<framework::SelectedRows>::kId; framework::VarTypeTrait<pten::SelectedRows>::kId;
constexpr auto kFetchListId = constexpr auto kFetchListId =
framework::VarTypeTrait<framework::FetchList>::kId; framework::VarTypeTrait<framework::FetchList>::kId;
valid_types_.insert(kTensorId); valid_types_.insert(kTensorId);
......
...@@ -50,9 +50,8 @@ class AssignFunctor { ...@@ -50,9 +50,8 @@ class AssignFunctor {
} }
} }
void operator()(const framework::SelectedRows &rows) const { void operator()(const pten::SelectedRows &rows) const {
framework::SelectedRows &out_rows = pten::SelectedRows &out_rows = *out_->GetMutable<pten::SelectedRows>();
*out_->GetMutable<framework::SelectedRows>();
out_rows.set_rows(rows.rows()); out_rows.set_rows(rows.rows());
out_rows.set_height(rows.height()); out_rows.set_height(rows.height());
auto &t = rows.value(); auto &t = rows.value();
......
...@@ -87,7 +87,7 @@ TEST(AssignOp, AssignSelectedRows) { ...@@ -87,7 +87,7 @@ TEST(AssignOp, AssignSelectedRows) {
std::vector<int64_t> rows{0, 4, 7}; std::vector<int64_t> rows{0, 4, 7};
int64_t height = 10; int64_t height = 10;
paddle::framework::SelectedRows input(rows, height); pten::SelectedRows input(rows, height);
paddle::framework::Tensor* input_tensor = input.mutable_value(); paddle::framework::Tensor* input_tensor = input.mutable_value();
paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4}); paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4});
...@@ -98,7 +98,7 @@ TEST(AssignOp, AssignSelectedRows) { ...@@ -98,7 +98,7 @@ TEST(AssignOp, AssignSelectedRows) {
assign_functor(input); assign_functor(input);
auto& out_selected_row = output.Get<paddle::framework::SelectedRows>(); auto& out_selected_row = output.Get<pten::SelectedRows>();
const paddle::framework::Vector<int64_t>& out_rows = out_selected_row.rows(); const paddle::framework::Vector<int64_t>& out_rows = out_selected_row.rows();
EXPECT_EQ(rows.size(), out_rows.size()); EXPECT_EQ(rows.size(), out_rows.size());
for (size_t i = 0; i < rows.size(); ++i) { for (size_t i = 0; i < rows.size(); ++i) {
......
...@@ -36,21 +36,22 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16> ...@@ -36,21 +36,22 @@ class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
output = context.Output<Tensor>("Out"); output = context.Output<Tensor>("Out");
output->mutable_data<platform::float16>(context.GetPlace()); output->mutable_data<platform::float16>(context.GetPlace());
} else if (in_var->IsType<SelectedRows>()) { } else if (in_var->IsType<pten::SelectedRows>()) {
auto* x = context.Input<SelectedRows>("X"); auto* x = context.Input<pten::SelectedRows>("X");
// merge ids in selected rows first // merge ids in selected rows first
math::scatter::MergeAdd<platform::CUDADeviceContext, platform::float16> math::scatter::MergeAdd<platform::CUDADeviceContext, platform::float16>
merge_func; merge_func;
SelectedRows* merged_input = pten::SelectedRows* merged_input =
const_cast<framework::Scope&>(context.scope()) const_cast<framework::Scope&>(context.scope())
.Var() .Var()
->GetMutable<SelectedRows>(); ->GetMutable<pten::SelectedRows>();
merge_func(context.template device_context<platform::CUDADeviceContext>(), merge_func(context.template device_context<platform::CUDADeviceContext>(),
*x, merged_input); *x, merged_input);
input = &(merged_input->value()); input = &(merged_input->value());
SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out"); pten::SelectedRows* output_selected_rows =
context.Output<pten::SelectedRows>("Out");
output_selected_rows->set_rows(merged_input->rows()); output_selected_rows->set_rows(merged_input->rows());
output_selected_rows->set_height(merged_input->height()); output_selected_rows->set_height(merged_input->height());
output = output_selected_rows->mutable_value(); output = output_selected_rows->mutable_value();
......
...@@ -24,7 +24,7 @@ namespace paddle { ...@@ -24,7 +24,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows; // using SelectedRows = pten::SelectedRows;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
...@@ -43,20 +43,21 @@ class ClipByNormKernel : public framework::OpKernel<T> { ...@@ -43,20 +43,21 @@ class ClipByNormKernel : public framework::OpKernel<T> {
output = context.Output<Tensor>("Out"); output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
} else if (in_var->IsType<SelectedRows>()) { } else if (in_var->IsType<pten::SelectedRows>()) {
auto* x = context.Input<SelectedRows>("X"); auto* x = context.Input<pten::SelectedRows>("X");
// merge ids in selected rows first // merge ids in selected rows first
math::scatter::MergeAdd<DeviceContext, T> merge_func; math::scatter::MergeAdd<DeviceContext, T> merge_func;
SelectedRows* merged_input = pten::SelectedRows* merged_input =
const_cast<framework::Scope&>(context.scope()) const_cast<framework::Scope&>(context.scope())
.Var() .Var()
->GetMutable<SelectedRows>(); ->GetMutable<pten::SelectedRows>();
merge_func(context.template device_context<DeviceContext>(), *x, merge_func(context.template device_context<DeviceContext>(), *x,
merged_input); merged_input);
input = &(merged_input->value()); input = &(merged_input->value());
SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out"); pten::SelectedRows* output_selected_rows =
context.Output<pten::SelectedRows>("Out");
output_selected_rows->set_rows(merged_input->rows()); output_selected_rows->set_rows(merged_input->rows());
output_selected_rows->set_height(merged_input->height()); output_selected_rows->set_height(merged_input->height());
output = output_selected_rows->mutable_value(); output = output_selected_rows->mutable_value();
......
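The sparse clip_by_norm path in both kernels above first merge-adds duplicate row ids, then clips the merged value tensor by its global L2 norm. A self-contained sketch of that two-step flow on toy (row id, value) pairs:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

int main() {
  // (row id -> value) pairs with a duplicate row 3: 1.0 + 2.0 merges to 3.0.
  std::vector<std::pair<int64_t, double>> in{{3, 1.0}, {5, 4.0}, {3, 2.0}};
  std::map<int64_t, double> merged;
  for (auto& p : in) merged[p.first] += p.second;  // MergeAdd

  double norm = 0.0;
  for (auto& kv : merged) norm += kv.second * kv.second;
  norm = std::sqrt(norm);                          // ||merged|| = 5.0

  const double max_norm = 2.5;
  if (norm > max_norm)                             // clip: scale by max/norm
    for (auto& kv : merged) kv.second *= max_norm / norm;

  assert(std::abs(merged[3] - 1.5) < 1e-9 && std::abs(merged[5] - 2.0) < 1e-9);
  return 0;
}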
...@@ -113,9 +113,9 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -113,9 +113,9 @@ class ClipKernel : public framework::OpKernel<T> {
trans(context.template device_context<DeviceContext>(), x_data, trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max)); x_data + numel, out_data, ClipFunctor<T>(min, max));
} }
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<pten::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X"); auto* x = context.Input<pten::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out"); auto* out = context.Output<pten::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument( PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument(
"Inplace clip is not allowed " "Inplace clip is not allowed "
"when x is SelectedRows")); "when x is SelectedRows"));
......
...@@ -32,7 +32,7 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T> ...@@ -32,7 +32,7 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
ctx.InputName("X"))); ctx.InputName("X")));
const auto& cuda_ctx = const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>(); ctx.template device_context<platform::CUDADeviceContext>();
if (x_var->IsType<framework::SelectedRows>()) { if (x_var->IsType<pten::SelectedRows>()) {
framework::Tensor x_for_selectedrows; framework::Tensor x_for_selectedrows;
std::vector<const framework::Tensor*> ins; std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs; std::vector<framework::Tensor*> outs;
......
...@@ -92,20 +92,20 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { ...@@ -92,20 +92,20 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<framework::LoDTensor>("Y"); auto* y = ctx.Input<framework::LoDTensor>("Y");
framework::Tensor x, *z; framework::Tensor x, *z;
if (x_var->IsType<framework::SelectedRows>()) { if (x_var->IsType<pten::SelectedRows>()) {
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"For elementwise_op, if X is Sparse, Y must be " "For elementwise_op, if X is Sparse, Y must be "
"scalar. But reveived the size of Y = %s.", "scalar. But reveived the size of Y = %s.",
y->dims().size())); y->dims().size()));
auto& x_sele = x_var->Get<framework::SelectedRows>(); auto& x_sele = x_var->Get<pten::SelectedRows>();
auto out_sele = ctx.Output<framework::SelectedRows>("Out"); auto out_sele = ctx.Output<pten::SelectedRows>("Out");
x = x_sele.value(); x = x_sele.value();
out_sele->set_rows(x_sele.rows()); out_sele->set_rows(x_sele.rows());
out_sele->set_height(x_sele.height()); out_sele->set_height(x_sele.height());
out_sele->mutable_value()->Resize(x_sele.value().dims()); out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value(); z = ctx.Output<pten::SelectedRows>("Out")->mutable_value();
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x.dims() == y->dims(); auto dims_equal = x.dims() == y->dims();
if (dims_equal) { if (dims_equal) {
......
...@@ -84,7 +84,7 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx, ...@@ -84,7 +84,7 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
auto *x = ctx.Input<framework::LoDTensor>("X"); auto *x = ctx.Input<framework::LoDTensor>("X");
z = ctx.Output<framework::LoDTensor>("Out"); z = ctx.Output<framework::LoDTensor>("Out");
ins->emplace_back(x); ins->emplace_back(x);
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<pten::SelectedRows>()) {
PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"For elementwise_op, if X is Sparse, Y must be " "For elementwise_op, if X is Sparse, Y must be "
...@@ -96,15 +96,15 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx, ...@@ -96,15 +96,15 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
"The parameter x_for_selectedrows is excepted to " "The parameter x_for_selectedrows is excepted to "
"be valid, once input varible X`s class type is " "be valid, once input varible X`s class type is "
"SelectedRows.\n")); "SelectedRows.\n"));
auto &x_sele = x_var->Get<framework::SelectedRows>(); auto &x_sele = x_var->Get<pten::SelectedRows>();
auto out_sele = ctx.Output<framework::SelectedRows>("Out"); auto out_sele = ctx.Output<pten::SelectedRows>("Out");
*x_for_selectedrows = x_sele.value(); *x_for_selectedrows = x_sele.value();
out_sele->set_rows(x_sele.rows()); out_sele->set_rows(x_sele.rows());
out_sele->set_height(x_sele.height()); out_sele->set_height(x_sele.height());
out_sele->mutable_value()->Resize(x_sele.value().dims()); out_sele->mutable_value()->Resize(x_sele.value().dims());
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), out_sele->mutable_value()->mutable_data(ctx.GetPlace(),
x_for_selectedrows->type()); x_for_selectedrows->type());
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value(); z = ctx.Output<pten::SelectedRows>("Out")->mutable_value();
ins->emplace_back(x_for_selectedrows); ins->emplace_back(x_for_selectedrows);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
......
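Both elementwise paths above treat a sparse X the same way: Y must be a one-element scalar, and Out is a SelectedRows that reuses X's rows and height while the multiply runs densely on X.value(). A toy sketch of that contract; the struct is an invented stand-in for pten::SelectedRows.

#include <cassert>
#include <cstdint>
#include <vector>

struct ToySR {
  int64_t height = 0;
  std::vector<int64_t> rows;
  std::vector<float> value;
};

ToySR MulScalar(const ToySR& x, float y) {
  ToySR out;
  out.height = x.height;  // out_sele->set_height(x_sele.height())
  out.rows = x.rows;      // out_sele->set_rows(x_sele.rows())
  out.value.reserve(x.value.size());
  for (float v : x.value) out.value.push_back(v * y);  // dense op on value()
  return out;
}

int main() {
  ToySR x{8, {0, 4}, {1.5f, -2.0f}};
  ToySR out = MulScalar(x, 2.0f);
  assert(out.rows == x.rows && out.value[0] == 3.0f && out.value[1] == -4.0f);
  return 0;
}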
...@@ -117,7 +117,7 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -117,7 +117,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
const auto& str_value = ctx.Attr<std::string>("str_value"); const auto& str_value = ctx.Attr<std::string>("str_value");
value = str_value.empty() ? "value" : "str_value"; value = str_value.empty() ? "value" : "str_value";
} }
if (!ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) { if (!ctx.OutputVar("Out")->IsType<pten::SelectedRows>()) {
return framework::KernelSignature("full", {}, {shape, value}, {"Out"}); return framework::KernelSignature("full", {}, {shape, value}, {"Out"});
} }
return framework::KernelSignature("fill_constant.unregistered", {}, {}, {}); return framework::KernelSignature("fill_constant.unregistered", {}, {}, {});
......
...@@ -92,8 +92,8 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -92,8 +92,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->Resize(shape); tensor->Resize(shape);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<pten::SelectedRows>()) {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var->GetMutable<pten::SelectedRows>()->mutable_value();
tensor->Resize(shape); tensor->Resize(shape);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
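The fill_constant branch above resolves one mutable tensor regardless of which variant the output variable holds — the LoDTensor itself, or the SelectedRows' value() — and then fills it. A self-contained sketch of that pattern, using std::variant as a stand-in for framework::Variable (C++17):

#include <cassert>
#include <variant>
#include <vector>

struct ToySR { std::vector<float> value; };
using Var = std::variant<std::vector<float>, ToySR>;  // LoDTensor | SelectedRows

std::vector<float>* MutableTensor(Var* v) {
  if (auto* dense = std::get_if<std::vector<float>>(v)) return dense;
  return &std::get<ToySR>(*v).value;  // SelectedRows case: fill its value()
}

int main() {
  Var dense = std::vector<float>(3);
  Var sparse = ToySR{std::vector<float>(2)};
  for (Var* v : {&dense, &sparse})
    for (float& x : *MutableTensor(v)) x = 7.0f;  // the "fill" step
  assert(std::get<std::vector<float>>(dense)[2] == 7.0f);
  assert(std::get<ToySR>(sparse).value[1] == 7.0f);
  return 0;
}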
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = pten::SelectedRows;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename T> template <typename T>
......
...@@ -30,7 +30,7 @@ namespace operators { ...@@ -30,7 +30,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
...@@ -200,8 +200,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -200,8 +200,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
DDim table_dim; DDim table_dim;
if (table_var->IsType<LoDTensor>()) { if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims(); table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) { } else if (table_var->IsType<pten::SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<pten::SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
...@@ -215,7 +215,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -215,7 +215,8 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
if (is_sparse) { if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
// runtime shape // runtime shape
d_table->set_height(table_dim[0]); d_table->set_height(table_dim[0]);
......
...@@ -57,7 +57,7 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
class GetTensorFromSelectedRowsKernel { class GetTensorFromSelectedRowsKernel {
public: public:
void operator()(const framework::ExecutionContext &ctx) const { void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.Input<framework::SelectedRows>("X"); auto *x = ctx.Input<pten::SelectedRows>("X");
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
out->Resize(x->value().dims()); out->Resize(x->value().dims());
......
...@@ -204,7 +204,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -204,7 +204,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
"Custom tree must be set for sparse mode!")); "Custom tree must be set for sparse mode!"));
framework::Vector<int64_t> real_rows = PathToRows(*path); framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = auto* w_grad =
ctx.Output<framework::SelectedRows>(framework::GradVarName("W")); ctx.Output<pten::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows); w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id // Build a map of id -> row_index to speed up finding the index of one id
w_grad->set_height(w.dims()[0]); w_grad->set_height(w.dims()[0]);
......
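In the sparse hierarchical-sigmoid gradient above, PathToRows gathers the row ids touched by the custom tree paths, and only those rows of W receive a gradient (set_rows(real_rows), set_height(w.dims()[0])). A sketch of what such a row-gathering step looks like, assuming distinct non-negative ids with -1 as padding; toy code, not Paddle's implementation.

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

std::vector<int64_t> PathToRows(const std::vector<int64_t>& path) {
  std::set<int64_t> ids;
  for (int64_t id : path)
    if (id >= 0) ids.insert(id);  // skip padding entries
  return {ids.begin(), ids.end()};
}

int main() {
  // Two samples' paths through the tree, flattened, with -1 padding.
  std::vector<int64_t> path{4, 2, -1, 4, 7, 2};
  std::vector<int64_t> rows = PathToRows(path);
  assert((rows == std::vector<int64_t>{2, 4, 7}));  // unique touched rows only
  return 0;
}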
...@@ -55,8 +55,8 @@ class OverflowOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,8 @@ class OverflowOp : public framework::OperatorWithKernel {
auto *x_var = ctx.InputVar("X"); auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) { if (x_var->IsType<framework::LoDTensor>()) {
dtype = x_var->Get<framework::LoDTensor>().type(); dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<pten::SelectedRows>()) {
dtype = x_var->Get<framework::SelectedRows>().value().type(); dtype = x_var->Get<pten::SelectedRows>().value().type();
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
true, false, true, false,
......
...@@ -58,8 +58,8 @@ class OverflowKernel : public framework::OpKernel<T> { ...@@ -58,8 +58,8 @@ class OverflowKernel : public framework::OpKernel<T> {
if (x->IsType<framework::LoDTensor>()) { if (x->IsType<framework::LoDTensor>()) {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
functor(*in, out); functor(*in, out);
} else if (x->IsType<framework::SelectedRows>()) { } else if (x->IsType<pten::SelectedRows>()) {
auto& in = ctx.Input<framework::SelectedRows>("X")->value(); auto& in = ctx.Input<pten::SelectedRows>("X")->value();
functor(in, out); functor(in, out);
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -62,8 +62,8 @@ class OverflowV2Op : public framework::OperatorWithKernel { ...@@ -62,8 +62,8 @@ class OverflowV2Op : public framework::OperatorWithKernel {
auto *x_var = ctx.InputVar("X"); auto *x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) { if (x_var->IsType<framework::LoDTensor>()) {
dtype = x_var->Get<framework::LoDTensor>().type(); dtype = x_var->Get<framework::LoDTensor>().type();
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<pten::SelectedRows>()) {
dtype = x_var->Get<framework::SelectedRows>().value().type(); dtype = x_var->Get<pten::SelectedRows>().value().type();
} else { } else {
PADDLE_THROW(plat::errors::InvalidArgument( PADDLE_THROW(plat::errors::InvalidArgument(
"Cannot find the input data type by all input data")); "Cannot find the input data type by all input data"));
......
...@@ -50,7 +50,7 @@ class LoadOpKernel : public framework::OpKernel<T> { ...@@ -50,7 +50,7 @@ class LoadOpKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx); LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<pten::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var); LoadSelectedRows(fin, place, out_var);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -105,7 +105,7 @@ class LoadOpKernel : public framework::OpKernel<T> { ...@@ -105,7 +105,7 @@ class LoadOpKernel : public framework::OpKernel<T> {
void LoadSelectedRows(std::istream &fin, const platform::Place &place, void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>(); auto *selectedRows = var->GetMutable<pten::SelectedRows>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
......
...@@ -29,7 +29,7 @@ namespace operators { ...@@ -29,7 +29,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
template <typename T> template <typename T>
......
...@@ -151,7 +151,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -151,7 +151,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W"); auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel(); int64_t ids_num = ids->numel();
......
...@@ -28,7 +28,7 @@ namespace operators { ...@@ -28,7 +28,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
...@@ -82,8 +82,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -82,8 +82,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
} }
} }
} else if (table_var->IsType<SelectedRows>()) { } else if (table_var->IsType<pten::SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>(); const auto &table_t = table_var->Get<pten::SelectedRows>();
int64_t row_width = table_t.value().dims()[1]; int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>(); const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
...@@ -155,8 +155,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -155,8 +155,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
DDim table_dim; DDim table_dim;
if (table_var->IsType<LoDTensor>()) { if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims(); table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) { } else if (table_var->IsType<pten::SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<pten::SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -171,7 +171,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -171,7 +171,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
if (is_sparse) { if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel(); int64_t ids_num = ids->numel();
......
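The sparse lookup_table gradients above all build d_table the same way: one sparse row per looked-up id (duplicates kept; a later MergeAdd can combine them), height set to the full table's row count, and values taken from the upstream output gradient. A toy sketch with width-1 rows:

#include <cassert>
#include <cstdint>
#include <vector>

struct ToySR {
  int64_t height = 0;
  std::vector<int64_t> rows;
  std::vector<float> value;
};

ToySR LookupTableGrad(const std::vector<int64_t>& ids,
                      const std::vector<float>& d_out, int64_t table_height) {
  ToySR d_table;
  d_table.height = table_height;  // set_height(table_dim[0])
  d_table.rows = ids;             // one sparse row per id, duplicates kept
  d_table.value = d_out;          // each row's grad is the upstream grad row
  return d_table;
}

int main() {
  ToySR g = LookupTableGrad({5, 2, 5}, {0.1f, 0.2f, 0.3f}, 10);
  assert(g.height == 10 && g.rows.size() == 3 && g.rows[2] == 5);
  return 0;
}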
...@@ -152,7 +152,8 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> { ...@@ -152,7 +152,8 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
auto *ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto *table = context.Input<LoDTensor>("W"); auto *table = context.Input<LoDTensor>("W");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
auto *ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel(); int64_t ids_num = ids->numel();
......
...@@ -29,7 +29,7 @@ namespace operators { ...@@ -29,7 +29,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = pten::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
...@@ -86,8 +86,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -86,8 +86,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
} else if (table_var->IsType<SelectedRows>()) { } else if (table_var->IsType<pten::SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>(); const auto &table_t = table_var->Get<pten::SelectedRows>();
int64_t row_width = table_t.value().dims()[1]; int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>(); const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
...@@ -132,8 +132,8 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> { ...@@ -132,8 +132,8 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
DDim table_dim; DDim table_dim;
if (table_var->IsType<LoDTensor>()) { if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims(); table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) { } else if (table_var->IsType<pten::SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<pten::SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -148,7 +148,8 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> { ...@@ -148,7 +148,8 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
if (is_sparse) { if (is_sparse) {
auto *ids_t = context.Input<LoDTensor>("Ids"); auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto *d_table =
context.Output<pten::SelectedRows>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel(); int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids; std::vector<int64_t> ids;
......
...@@ -227,11 +227,11 @@ template <typename T> ...@@ -227,11 +227,11 @@ template <typename T>
struct MatrixBitCodeFunctorMulGradWeightSR struct MatrixBitCodeFunctorMulGradWeightSR
: public boost::static_visitor<void> { : public boost::static_visitor<void> {
const framework::Tensor &tmat_; const framework::Tensor &tmat_;
framework::SelectedRows *weight_; pten::SelectedRows *weight_;
const framework::Tensor &input_; const framework::Tensor &input_;
MatrixBitCodeFunctorMulGradWeightSR(const framework::Tensor &tmat, MatrixBitCodeFunctorMulGradWeightSR(const framework::Tensor &tmat,
framework::SelectedRows *weight, pten::SelectedRows *weight,
const framework::Tensor &input) const framework::Tensor &input)
: tmat_(tmat), weight_(weight), input_(input) {} : tmat_(tmat), weight_(weight), input_(input) {}
...@@ -274,7 +274,7 @@ struct MatrixBitCodeFunctorMulGradWeightSR ...@@ -274,7 +274,7 @@ struct MatrixBitCodeFunctorMulGradWeightSR
template <typename T> template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat, void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
framework::SelectedRows *weight, pten::SelectedRows *weight,
const framework::Tensor &input) { const framework::Tensor &input) {
MatrixBitCodeFunctorMulGradWeightSR<T> func(tmat, weight, input); MatrixBitCodeFunctorMulGradWeightSR<T> func(tmat, weight, input);
code_table_.apply_visitor(func); code_table_.apply_visitor(func);
......
...@@ -252,8 +252,7 @@ class MatrixBitCodeFunctor { ...@@ -252,8 +252,7 @@ class MatrixBitCodeFunctor {
/* For SelectedRows Weight, For index(i, j) >= 0: /* For SelectedRows Weight, For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i) weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/ */
void MulGradWeight(const framework::Tensor& tmat, void MulGradWeight(const framework::Tensor& tmat, pten::SelectedRows* weight,
framework::SelectedRows* weight,
const framework::Tensor& input); const framework::Tensor& input);
/* For j < code_length /* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j)) input.row(i) += tmat(i, j) * weight.row(index(i, j))
......
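MulGradWeight's SelectedRows overload above implements the rule quoted in the header comment: for every valid position (i, j), weight.row(index(i, j)) += tmat(i, j) * input.row(i). A self-contained sketch on dense toy matrices, with index < 0 marking an unused position in this sketch:

#include <cassert>
#include <vector>

int main() {
  const int n = 2, code_len = 2, dim = 2, n_rows = 3;
  // index(i, j): which weight row each code bit touches.
  std::vector<std::vector<int>> index{{0, 2}, {2, -1}};
  std::vector<std::vector<double>> tmat{{1.0, 2.0}, {3.0, 0.0}};
  std::vector<std::vector<double>> input{{1.0, 1.0}, {0.5, 2.0}};
  std::vector<std::vector<double>> weight(n_rows, std::vector<double>(dim, 0.0));

  for (int i = 0; i < n; ++i)
    for (int j = 0; j < code_len; ++j) {
      int row = index[i][j];
      if (row < 0) continue;  // skip unused positions
      for (int d = 0; d < dim; ++d)
        weight[row][d] += tmat[i][j] * input[i][d];  // the quoted update
    }

  // row 2 gets 2.0*input.row(0) + 3.0*input.row(1) = {3.5, 8.0}.
  assert(weight[2][0] == 3.5 && weight[2][1] == 8.0);
  return 0;
}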
...@@ -24,9 +24,9 @@ namespace math { ...@@ -24,9 +24,9 @@ namespace math {
template <typename T> template <typename T>
struct SelectedRowsAdd<platform::CPUDeviceContext, T> { struct SelectedRowsAdd<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input1, const pten::SelectedRows& input1,
const framework::SelectedRows& input2, const pten::SelectedRows& input2,
framework::SelectedRows* output) { pten::SelectedRows* output) {
auto in1_height = input1.height(); auto in1_height = input1.height();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in1_height, input2.height(), in1_height, input2.height(),
...@@ -94,7 +94,7 @@ template struct SelectedRowsAdd<platform::CPUDeviceContext, double>; ...@@ -94,7 +94,7 @@ template struct SelectedRowsAdd<platform::CPUDeviceContext, double>;
template <typename T> template <typename T>
struct SelectedRowsAddTensor<platform::CPUDeviceContext, T> { struct SelectedRowsAddTensor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input1, const pten::SelectedRows& input1,
const framework::Tensor& input2, framework::Tensor* output) { const framework::Tensor& input2, framework::Tensor* output) {
auto in1_height = input1.height(); auto in1_height = input1.height();
auto in2_dims = input2.dims(); auto in2_dims = input2.dims();
...@@ -154,9 +154,8 @@ template struct SelectedRowsAddTensor<platform::CPUDeviceContext, double>; ...@@ -154,9 +154,8 @@ template struct SelectedRowsAddTensor<platform::CPUDeviceContext, double>;
template <typename T> template <typename T>
struct SelectedRowsAddTo<platform::CPUDeviceContext, T> { struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const int64_t input2_offset,
-                  framework::SelectedRows* input2) {
+                  const pten::SelectedRows& input1, const int64_t input2_offset,
+                  pten::SelectedRows* input2) {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(
         in1_height, input2->height(),
@@ -198,9 +197,9 @@ template struct SelectedRowsAddTo<platform::CPUDeviceContext, int64_t>;
 template <typename T>
 struct SelectedRowsSumTo<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::vector<framework::SelectedRows*>& input1,
+                  const std::vector<pten::SelectedRows*>& input1,
                   const std::vector<int64_t>& input2_offsets,
-                  framework::SelectedRows* input2) {
+                  pten::SelectedRows* input2) {
     // Ensure all selected rows have the same height
     size_t size = 0u;
     for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
@@ -242,8 +241,7 @@ template struct SelectedRowsSumTo<platform::CPUDeviceContext, double>;
 template <typename T>
 struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  framework::Tensor* input2) {
+                  const pten::SelectedRows& input1, framework::Tensor* input2) {
     if (UNLIKELY(input1.rows().size() == 0)) {
       LOG(WARNING) << "input selected rows is empty!";
       return;
@@ -313,7 +311,7 @@ typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
 template <typename T>
 typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
-add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
+add_sparse_inputs(const std::vector<const pten::SelectedRows*>& inputs,
                   const std::unordered_map<int64_t, size_t>& rows_to_id,
                   int64_t input_width,
                   const platform::CPUDeviceContext& context, T* out_data) {
@@ -347,7 +345,7 @@ add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
 template <typename T>
 typename std::enable_if<!std::is_same<T, platform::bfloat16>::value>::type
-add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
+add_sparse_inputs(const std::vector<const pten::SelectedRows*>& inputs,
                   const std::unordered_map<int64_t, size_t>& rows_to_id,
                   int64_t input_width,
                   const platform::CPUDeviceContext& context, T* out_data) {
@@ -371,32 +369,31 @@ add_sparse_inputs(const std::vector<const framework::SelectedRows*>& inputs,
 template <typename T>
 struct MergeAdd<platform::CPUDeviceContext, T> {
-  framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
-                                     const framework::SelectedRows& input,
-                                     const bool sorted_result = false) {
-    framework::SelectedRows out;
+  pten::SelectedRows operator()(const platform::CPUDeviceContext& context,
+                                const pten::SelectedRows& input,
+                                const bool sorted_result = false) {
+    pten::SelectedRows out;
     (*this)(context, input, &out, sorted_result);
     return out;
   }
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output,
+                  const pten::SelectedRows& input, pten::SelectedRows* output,
                   const bool sorted_result = false) {
-    std::vector<const framework::SelectedRows*> inputs;
+    std::vector<const pten::SelectedRows*> inputs;
     inputs.push_back(&input);
     (*this)(context, inputs, output, sorted_result);
   }
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output,
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output,
                   const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
     }
-    const framework::SelectedRows* has_value_input = nullptr;
+    const pten::SelectedRows* has_value_input = nullptr;
     for (auto* in : inputs) {
       if (in->rows().size() > 0) {
         has_value_input = in;
@@ -409,7 +406,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     }
     auto input_width = has_value_input->value().dims()[1];
     auto input_height = has_value_input->height();
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> merged_row_set;
     size_t row_num = 0;
     for (auto* input : inputs) {
@@ -480,24 +477,23 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
 #ifdef PADDLE_WITH_XPU
 template <typename T>
 struct MergeAdd<platform::XPUDeviceContext, T> {
-  framework::SelectedRows operator()(const platform::XPUDeviceContext& context,
-                                     const framework::SelectedRows& input,
-                                     const bool sorted_result = false) {
-    framework::SelectedRows out;
+  pten::SelectedRows operator()(const platform::XPUDeviceContext& context,
+                                const pten::SelectedRows& input,
+                                const bool sorted_result = false) {
+    pten::SelectedRows out;
     (*this)(context, input, &out, sorted_result);
     return out;
   }
   void operator()(const platform::XPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output,
+                  const pten::SelectedRows& input, pten::SelectedRows* output,
                   const bool sorted_result = false) {
     framework::Vector<int64_t> input_rows(input.rows());
     if (input_rows.size() == 0) {
       return;
     }
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
     auto input_width = input.value().dims()[1];
@@ -537,14 +533,14 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
   }
   void operator()(const platform::XPUDeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output,
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output,
                   const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
     }
-    const framework::SelectedRows* has_value_input = nullptr;
+    const pten::SelectedRows* has_value_input = nullptr;
     for (auto* in : inputs) {
       if (in->rows().size() > 0) {
         has_value_input = in;
@@ -557,7 +553,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
     }
     auto input_width = has_value_input->value().dims()[1];
     auto input_height = has_value_input->height();
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> merged_row_set;
     size_t row_num = 0;
     for (auto* input : inputs) {
@@ -628,29 +624,28 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
 #endif
 template <typename T>
 struct MergeAverage<platform::CPUDeviceContext, T> {
-  framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
-                                     const framework::SelectedRows& input) {
-    framework::SelectedRows out;
+  pten::SelectedRows operator()(const platform::CPUDeviceContext& context,
+                                const pten::SelectedRows& input) {
+    pten::SelectedRows out;
     (*this)(context, input, &out);
     return out;
   }
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output) {
-    std::vector<const framework::SelectedRows*> inputs;
+                  const pten::SelectedRows& input, pten::SelectedRows* output) {
+    std::vector<const pten::SelectedRows*> inputs;
     inputs.push_back(&input);
     (*this)(context, inputs, output);
   }
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output) {
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
     }
-    const framework::SelectedRows* has_value_input = nullptr;
+    const pten::SelectedRows* has_value_input = nullptr;
     for (auto* in : inputs) {
       if (in->rows().size() > 0) {
         has_value_input = in;
@@ -663,7 +658,7 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
     }
     auto input_width = has_value_input->value().dims()[1];
     auto input_height = has_value_input->height();
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> merged_row_set;
     size_t row_num = 0;
     for (auto* input : inputs) {
@@ -750,7 +745,7 @@ template struct MergeAverage<platform::CPUDeviceContext, double>;
 template <typename T>
 struct UpdateToTensor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
                   const ScatterOps& op,
-                  const framework::SelectedRows& input1,
+                  const pten::SelectedRows& input1,
                   framework::Tensor* input2) {
     auto in1_height = input1.height();
     auto in2_dims = input2->dims();
...
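The CPU MergeAdd above is untouched by the rename apart from its types: it
unions the row indices of every input into merged_row_set, gives each distinct
row one output slot, then accumulates all input slices into their slots. A
minimal standalone sketch of that scheme on plain std:: containers —
SparseRows and MergeAddRef are illustrative names, not Paddle APIs:

  #include <cstdint>
  #include <map>
  #include <vector>

  // Stand-in for a SelectedRows: values[i] is the width-wide slice stored for
  // row index rows[i]; indices may repeat across and within inputs.
  struct SparseRows {
    std::vector<int64_t> rows;
    std::vector<std::vector<float>> values;
  };

  SparseRows MergeAddRef(const std::vector<SparseRows>& inputs) {
    std::map<int64_t, size_t> slot;  // distinct row index -> output slot
    size_t width = 0;
    for (const auto& in : inputs) {
      for (int64_t r : in.rows) slot.emplace(r, 0);
      if (!in.values.empty()) width = in.values.front().size();
    }
    SparseRows out;
    for (auto& kv : slot) {
      kv.second = out.rows.size();
      out.rows.push_back(kv.first);
      out.values.emplace_back(width, 0.0f);
    }
    // Accumulate every input slice into its merged slot.
    for (const auto& in : inputs)
      for (size_t i = 0; i < in.rows.size(); ++i)
        for (size_t j = 0; j < width; ++j)
          out.values[slot[in.rows[i]]][j] += in.values[i][j];
    return out;
  }

With rows {0, 4, 4, 7}, as in the unit tests further down, the merged result
has rows {0, 4, 7} and the slot for row 4 holds the sum of the two duplicated
slices.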
@@ -26,9 +26,9 @@ namespace math {
 template <typename T>
 struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* output) {
+                  const pten::SelectedRows& input1,
+                  const pten::SelectedRows& input2,
+                  pten::SelectedRows* output) {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(
         in1_height, input2.height(),
@@ -117,7 +117,7 @@ __global__ void SelectedRowsAddTensorKernel(const T* selected_rows,
 template <typename T>
 struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input1,
+                  const pten::SelectedRows& input1,
                   const framework::Tensor& input2, framework::Tensor* output) {
     auto in1_height = input1.height();
     auto in2_dims = input2.dims();
@@ -182,9 +182,8 @@ template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
 template <typename T>
 struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const int64_t input2_offset,
-                  framework::SelectedRows* input2) {
+                  const pten::SelectedRows& input1, const int64_t input2_offset,
+                  pten::SelectedRows* input2) {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(
         in1_height, input2->height(),
@@ -250,8 +249,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
 template <typename T>
 struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  framework::Tensor* input2) {
+                  const pten::SelectedRows& input1, framework::Tensor* input2) {
     auto in1_height = input1.height();
     auto in2_dims = input2->dims();
     PADDLE_ENFORCE_EQ(
@@ -320,24 +318,23 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 template <typename T>
 struct MergeAdd<platform::CUDADeviceContext, T> {
-  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
-                                     const framework::SelectedRows& input,
-                                     const bool sorted_result = false) {
-    framework::SelectedRows out;
+  pten::SelectedRows operator()(const platform::CUDADeviceContext& context,
+                                const pten::SelectedRows& input,
+                                const bool sorted_result = false) {
+    pten::SelectedRows out;
     (*this)(context, input, &out);
     return out;
   }
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output,
+                  const pten::SelectedRows& input, pten::SelectedRows* output,
                   const bool sorted_result = false) {
     framework::Vector<int64_t> input_rows(input.rows());
     if (input_rows.size() == 0) {
       return;
     }
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
     framework::Vector<int64_t> merge_rows(merge_rows_cpu);
@@ -368,14 +365,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   }
   void operator()(const platform::CUDADeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output,
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output,
                   const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
     }
-    const framework::SelectedRows* has_value_input = nullptr;
+    const pten::SelectedRows* has_value_input = nullptr;
     for (auto* in : inputs) {
       if (in->rows().size() > 0) {
         has_value_input = in;
@@ -388,7 +385,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
     }
     auto input_width = has_value_input->value().dims()[1];
     auto input_height = has_value_input->height();
-    framework::SelectedRows& out = *output;
+    pten::SelectedRows& out = *output;
     std::set<int64_t> merged_row_set;
     for (auto* input : inputs) {
       if (input->rows().size() == 0) {
@@ -499,7 +496,7 @@ __global__ void UpdateToTensorKernel(const T* selected_rows,
 template <typename T>
 struct UpdateToTensor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
-                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  const ScatterOps& op, const pten::SelectedRows& input1,
                   framework::Tensor* input2) {
     // NOTE: Use SelectedRowsAddToTensor for better performance
     // no additional MergeAdd called.
...
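In the CUDA specialization above the row union is still built on the host
(row_set / merge_rows_cpu); only the accumulation runs on the device. A
simplified scatter-add kernel in the spirit of MergeAddKernel — the name
ScatterAddRows, its signature, and the launch shape are illustrative
assumptions, not Paddle's actual kernel:

  // One block per input row; out_map[row] is the merged output slot of that
  // row's index. atomicAdd resolves duplicated row indices whose blocks add
  // into the same slot concurrently.
  __global__ void ScatterAddRows(const float* in, const int* out_map,
                                 float* out, int row_width) {
    int row = blockIdx.x;
    for (int col = threadIdx.x; col < row_width; col += blockDim.x) {
      atomicAdd(&out[out_map[row] * row_width + col],
                in[row * row_width + col]);
    }
  }

Launched as ScatterAddRows<<<num_input_rows, 256>>>(...) over a zero-filled
output buffer, this reproduces the merge-by-add semantics on the GPU.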
@@ -35,15 +35,14 @@ namespace math {
 template <typename DeviceContext, typename T>
 struct SelectedRowsAdd {
   void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* output);
+                  const pten::SelectedRows& input1,
+                  const pten::SelectedRows& input2, pten::SelectedRows* output);
 };
 template <typename DeviceContext, typename T>
 struct SelectedRowsAddTensor {
   void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
+                  const pten::SelectedRows& input1,
                   const framework::Tensor& input2, framework::Tensor* output);
 };
@@ -51,17 +50,17 @@ struct SelectedRowsAddTensor {
 template <typename DeviceContext, typename T>
 struct SelectedRowsAddTo {
   void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const int64_t input2_offset, framework::SelectedRows* input2);
+                  const pten::SelectedRows& input1, const int64_t input2_offset,
+                  pten::SelectedRows* input2);
 };
 // input2 = [all input in input1] + input2
 template <typename DeviceContext, typename T>
 struct SelectedRowsSumTo {
   void operator()(const DeviceContext& context,
-                  const std::vector<framework::SelectedRows*>& input1,
+                  const std::vector<pten::SelectedRows*>& input1,
                   const std::vector<int64_t>& input2_offsets,
-                  framework::SelectedRows* input2);
+                  pten::SelectedRows* input2);
 };
 // FIXME: The result of SelectedRowsAddToTensor maybe non deterministic,
@@ -70,8 +69,7 @@ struct SelectedRowsSumTo {
 template <typename DeviceContext, typename T>
 struct SelectedRowsAddToTensor {
   void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  framework::Tensor* input2);
+                  const pten::SelectedRows& input1, framework::Tensor* input2);
 };
 namespace scatter {
@@ -80,29 +78,25 @@ template <typename DeviceContext, typename T>
 struct MergeAdd {
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
-  framework::SelectedRows operator()(const DeviceContext& context,
-                                     const framework::SelectedRows& input,
-                                     const bool sorted_result = false);
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output,
-                  const bool sorted_result = false);
-  void operator()(const DeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output,
-                  const bool sorted_result = false);
+  pten::SelectedRows operator()(const DeviceContext& context,
+                                const pten::SelectedRows& input,
+                                const bool sorted_result = false);
+  void operator()(const DeviceContext& context, const pten::SelectedRows& input,
+                  pten::SelectedRows* output, const bool sorted_result = false);
+  void operator()(const DeviceContext& context,
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output, const bool sorted_result = false);
 };
 template <typename DeviceContext, typename T>
 struct MergeAverage {
-  framework::SelectedRows operator()(const DeviceContext& context,
-                                     const framework::SelectedRows& input);
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* output);
+  pten::SelectedRows operator()(const DeviceContext& context,
+                                const pten::SelectedRows& input);
+  void operator()(const DeviceContext& context, const pten::SelectedRows& input,
+                  pten::SelectedRows* output);
   void operator()(const DeviceContext& context,
-                  const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output);
+                  const std::vector<const pten::SelectedRows*>& inputs,
+                  pten::SelectedRows* output);
 };
 enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
@@ -111,8 +105,7 @@ enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
 template <typename DeviceContext, typename T>
 struct UpdateToTensor {
   void operator()(const DeviceContext& context, const ScatterOps& op,
-                  const framework::SelectedRows& input1,
-                  framework::Tensor* input2);
+                  const pten::SelectedRows& input1, framework::Tensor* input2);
 };
 }  // namespace scatter
...
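For callers of this header the migration is a namespace swap; the functor
interface itself is unchanged. A hedged usage sketch of scatter::MergeAdd with
the new type, mirroring the unit tests below (device-context setup and value
filling are simplified):

  paddle::platform::CPUPlace cpu_place;
  paddle::platform::CPUDeviceContext ctx(cpu_place);

  std::vector<int64_t> rows{0, 4, 4, 7};  // row 4 is duplicated
  pten::SelectedRows input(rows, /*height=*/10);
  input.mutable_value()->mutable_data<float>(
      paddle::framework::make_ddim({4, 10}), cpu_place);  // then fill values

  pten::SelectedRows merged;
  paddle::operators::math::scatter::MergeAdd<
      paddle::platform::CPUDeviceContext, float>
      merge_add_functor;
  merge_add_functor(ctx, input, &merged);  // merged.rows() == {0, 4, 7}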
@@ -27,8 +27,8 @@ TEST(selected_rows_functor, cpu_add) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows1{0, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -37,8 +37,8 @@ TEST(selected_rows_functor, cpu_add) {
   functor(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{0, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -46,8 +46,7 @@ TEST(selected_rows_functor, cpu_add) {
       cpu_place);
   functor(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   auto* out_value = output->mutable_value();
   // simply concat two SelectedRows
@@ -130,8 +129,8 @@ TEST(selected_rows_functor, cpu_add_to) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows1{0, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -140,8 +139,8 @@ TEST(selected_rows_functor, cpu_add_to) {
   functor(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{0, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -149,8 +148,7 @@ TEST(selected_rows_functor, cpu_add_to) {
       cpu_place);
   functor(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   auto* out_value = output->mutable_value();
@@ -230,8 +228,8 @@ TEST(selected_rows_functor, cpu_merge_average_float) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows{0, 4, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows{
-      new paddle::framework::SelectedRows(rows, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows{
+      new pten::SelectedRows(rows, height)};
   auto* in_value = selected_rows->mutable_value();
   in_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -242,8 +240,7 @@ TEST(selected_rows_functor, cpu_merge_average_float) {
   paddle::operators::math::scatter::MergeAverage<
       paddle::platform::CPUDeviceContext, float>
       merge_average_functor;
-  paddle::framework::SelectedRows output =
-      merge_average_functor(ctx, *selected_rows);
+  pten::SelectedRows output = merge_average_functor(ctx, *selected_rows);
   auto out_height = output.height();
   EXPECT_EQ(out_height, height);
@@ -270,8 +267,8 @@ TEST(selected_rows_functor, cpu_merge_add_float) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows{0, 4, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows{
-      new paddle::framework::SelectedRows(rows, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows{
+      new pten::SelectedRows(rows, height)};
   auto* in_value = selected_rows->mutable_value();
   in_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -279,8 +276,7 @@ TEST(selected_rows_functor, cpu_merge_add_float) {
       cpu_place);
   functor(ctx, in_value, 1.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
                                              float>
@@ -311,8 +307,8 @@ TEST(selected_rows_functor, cpu_merge_add_int) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows{0, 4, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows{
-      new paddle::framework::SelectedRows(rows, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows{
+      new pten::SelectedRows(rows, height)};
   auto* in_value = selected_rows->mutable_value();
   in_value->mutable_data<int>(
       paddle::framework::make_ddim(
@@ -320,8 +316,7 @@ TEST(selected_rows_functor, cpu_merge_add_int) {
       cpu_place);
   functor(ctx, in_value, 1);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
                                              int>
@@ -354,8 +349,8 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
   int64_t row_numel = 8;
   std::vector<int64_t> rows1{5, 2, 5, 3, 5};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -364,8 +359,8 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
   set_const(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{2, 5, 3, 5, 3};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -373,14 +368,13 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
       cpu_place);
   set_const(ctx, in2_value, 1.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
                                              float>
       merge_add_functor;
-  std::vector<const paddle::framework::SelectedRows*> inputs;
+  std::vector<const pten::SelectedRows*> inputs;
   inputs.push_back(selected_rows1.get());
   inputs.push_back(selected_rows2.get());
   merge_add_functor(ctx, inputs, output.get());
@@ -411,8 +405,8 @@ TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) {
   int64_t row_numel = 8;
   std::vector<int64_t> rows1{1, 3, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -421,8 +415,8 @@ TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) {
   set_const(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{0, 2, 4, 6, 8};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -430,14 +424,13 @@ TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) {
       cpu_place);
   set_const(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
                                              float>
       merge_add_functor;
-  std::vector<const paddle::framework::SelectedRows*> inputs;
+  std::vector<const pten::SelectedRows*> inputs;
   inputs.push_back(selected_rows1.get());
   inputs.push_back(selected_rows2.get());
   merge_add_functor(ctx, inputs, output.get());
@@ -472,8 +465,8 @@ TEST(selected_rows_functor, cpu_sum_to) {
   int64_t height = 10;
   int64_t row_numel = 10;
   std::vector<int64_t> rows1{0, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -482,8 +475,8 @@ TEST(selected_rows_functor, cpu_sum_to) {
   functor(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{0, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -491,8 +484,7 @@ TEST(selected_rows_functor, cpu_sum_to) {
       cpu_place);
   functor(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   auto* out_value = output->mutable_value();
   // simply concat two SelectedRows
@@ -501,7 +493,7 @@ TEST(selected_rows_functor, cpu_sum_to) {
   paddle::operators::math::SelectedRowsSumTo<paddle::platform::CPUDeviceContext,
                                              float>
       sum_to_functor;
-  sum_to_functor(ctx, std::vector<paddle::framework::SelectedRows*>(
+  sum_to_functor(ctx, std::vector<pten::SelectedRows*>(
                           {selected_rows1.get(), selected_rows2.get()}),
                  std::vector<int64_t>({0, in1_value->numel()}), output.get());
   auto out_height = output->height();
...
@@ -29,8 +29,8 @@ TEST(selected_rows_functor, gpu_add) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows1{0, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -48,8 +48,8 @@ TEST(selected_rows_functor, gpu_add) {
 #endif
   std::vector<int64_t> rows2{0, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -57,8 +57,7 @@ TEST(selected_rows_functor, gpu_add) {
       gpu_place);
   functor(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   auto* out_value = output->mutable_value();
   // simply concat two SelectedRows
@@ -152,8 +151,8 @@ TEST(selected_rows_functor, gpu_add_to) {
   int64_t row_numel = 10;
   std::vector<int64_t> rows1{0, 4, 7};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -162,8 +161,8 @@ TEST(selected_rows_functor, gpu_add_to) {
   functor(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{0, 5, 7, 9};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -171,8 +170,7 @@ TEST(selected_rows_functor, gpu_add_to) {
       gpu_place);
   functor(ctx, in2_value, 2.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   auto* out_value = output->mutable_value();
@@ -264,8 +262,8 @@ TEST(selected_rows_functor, gpu_merge_add) {
   int64_t row_numel = 8;
   std::vector<int64_t> rows1{5, 2, 5, 3, 5};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
-      new paddle::framework::SelectedRows(rows1, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows1{
+      new pten::SelectedRows(rows1, height)};
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -274,8 +272,8 @@ TEST(selected_rows_functor, gpu_merge_add) {
   set_const(ctx, in1_value, 1.0);
   std::vector<int64_t> rows2{2, 5, 3, 5, 3};
-  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
-      new paddle::framework::SelectedRows(rows2, height)};
+  std::unique_ptr<pten::SelectedRows> selected_rows2{
+      new pten::SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
@@ -283,14 +281,13 @@ TEST(selected_rows_functor, gpu_merge_add) {
       gpu_place);
   set_const(ctx, in2_value, 1.0);
-  std::unique_ptr<paddle::framework::SelectedRows> output{
-      new paddle::framework::SelectedRows()};
+  std::unique_ptr<pten::SelectedRows> output{new pten::SelectedRows()};
   output->set_height(height);
   paddle::operators::math::scatter::MergeAdd<
       paddle::platform::CUDADeviceContext, float>
       merge_add_functor;
-  std::vector<const paddle::framework::SelectedRows*> inputs;
+  std::vector<const pten::SelectedRows*> inputs;
   inputs.push_back(selected_rows1.get());
   inputs.push_back(selected_rows2.get());
   merge_add_functor(ctx, inputs, output.get());
...
@@ -51,7 +51,7 @@ class MemcpyD2HFunctor {
   }
   }
-  void operator()(const framework::SelectedRows &rows) const {
+  void operator()(const pten::SelectedRows &rows) const {
     // (JZ-LIANG) to support SelectedRows
     PADDLE_THROW(platform::errors::Unimplemented(
         "Memcpy for SelectedRows is NOT support yet."));
...
[12 additional file diffs collapsed in the original view]