提交 755ad257 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_embedding_to_phi

...@@ -52,12 +52,12 @@ tools/__pycache__ ...@@ -52,12 +52,12 @@ tools/__pycache__
# This file is automatically generated. # This file is automatically generated.
# TODO(zhiqiang) Move this file to build directory. # TODO(zhiqiang) Move this file to build directory.
paddle/infrt/dialect/pd_ops.td paddle/infrt/dialect/pd/ir/pd_ops.td
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td
tools/infrt/kernels.json tools/infrt/kernels.json
tools/infrt/kernel_signature.json tools/infrt/kernel_signature.json
paddle/infrt/dialect/pd_ops_info.h paddle/infrt/dialect/pd/common/pd_ops_info.h
.lit_test_times.txt .lit_test_times.txt
paddle/infrt/tests/dialect/Output paddle/infrt/tests/dialect/Output
paddle/infrt/tests/lit.cfg.py paddle/infrt/tests/lit.cfg.py
......
...@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS ...@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS
-DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH} -DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH}
-DWITH_STATIC=OFF -DWITH_STATIC=OFF
-DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${PADDLE2ONNX_INSTALL_DIR}/${LIBDIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS} ${EXTERNAL_OPTIONAL_ARGS}
......
cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api) cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api) cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api string_helper)
if (WITH_DISTRIBUTE) if (WITH_DISTRIBUTE)
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper) cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
......
...@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank, ...@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank,
"Only CPU place is supported for ProcessGroupGloo.")); "Only CPU place is supported for ProcessGroupGloo."));
} }
ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, ProcessGroupGloo::ProcessGroupGloo(
int rank, int world_size, const std::shared_ptr<paddle::distributed::Store>& store, int rank,
const std::shared_ptr<GlooOptions> options) int world_size, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(store) { : ProcessGroup(rank, world_size), _tag(0), _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size); _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store = auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store); ::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
......
...@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup {
class GlooStore : public ::gloo::rendezvous::Store { class GlooStore : public ::gloo::rendezvous::Store {
public: public:
explicit GlooStore( explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
const std::shared_ptr<paddle::distributed::TCPStore>& store)
: _store(store) {} : _store(store) {}
~GlooStore() = default; ~GlooStore() = default;
...@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup {
} }
protected: protected:
std::shared_ptr<paddle::distributed::TCPStore> _store; std::shared_ptr<paddle::distributed::Store> _store;
}; };
class GlooOptions { class GlooOptions {
...@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<::gloo::transport::Device> device; std::shared_ptr<::gloo::transport::Device> device;
}; };
explicit ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, int rank, explicit ProcessGroupGloo(
int world_size, const std::shared_ptr<paddle::distributed::Store>& store, int rank,
std::shared_ptr<GlooOptions> options); int world_size, std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default; ~ProcessGroupGloo() = default;
...@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup {
protected: protected:
uint32_t _tag; uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context; std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<GlooStore> _store; std::shared_ptr<::gloo::rendezvous::Store> _store;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -17,6 +17,20 @@ ...@@ -17,6 +17,20 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
static Backend TransToBackend(platform::Place place) {
static const std::map<phi::AllocationType, Backend> type_backend = {
{phi::AllocationType::GPU, Backend::GPU},
{phi::AllocationType::CPU, Backend::CPU},
};
phi::AllocationType type = place.GetType();
auto it = type_backend.find(type);
PADDLE_ENFORCE_EQ(it != type_backend.end(), true,
platform::errors::InvalidArgument(
"Place type (%s) is not supported. ", place));
return it->second;
}
std::vector<std::vector<size_t>> Eager_AssignGroupBySize( std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor> tensors, const std::vector<Tensor> tensors,
const std::vector<bool> &is_sparse_gradient, const std::vector<bool> &is_sparse_gradient,
...@@ -297,10 +311,18 @@ EagerReducer::EagerReducer( ...@@ -297,10 +311,18 @@ EagerReducer::EagerReducer(
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node); std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook( accumulation_grad_node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook)); std::make_shared<egr::CppTensorVoidHook>(reduce_hook));
gradnode_index_map_[grad_node.get()] = global_var_index;
} }
vars_marked_ready_.resize(tensors_.size(), false); vars_marked_ready_.resize(tensors_.size(), false);
local_used_vars_.resize(tensors_.size(), 0); local_used_vars_.resize(tensors_.size(), 0);
if (find_unused_vars_each_step_) {
global_used_vars_ = paddle::experimental::empty(
ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
TransToBackend(inner_place_));
}
} }
std::shared_ptr<egr::GradNodeBase> EagerReducer::GetGradNodeFromTensor( std::shared_ptr<egr::GradNodeBase> EagerReducer::GetGradNodeFromTensor(
...@@ -341,21 +363,10 @@ void EagerReducer::InitializeGroups( ...@@ -341,21 +363,10 @@ void EagerReducer::InitializeGroups(
} else { } else {
// process the dense gradient. // process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group); InitializeDenseGroups(tensor_indices_, &group);
experimental::Backend backend; // experimental::Backend backend = TransToBackend(inner_place_);
switch (inner_place_.GetType()) {
case phi::AllocationType::GPU:
backend = experimental::Backend::GPU;
break;
case phi::AllocationType::CPU:
backend = experimental::Backend::CPU;
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Place type (%s) is not supported. ", inner_place_));
break;
}
group.dense_contents_ = paddle::experimental::empty( group.dense_contents_ = paddle::experimental::empty(
ScalarArray({group.all_length_}), group.dtype_, backend); ScalarArray({group.all_length_}), group.dtype_,
TransToBackend(inner_place_));
} }
// map tensors to this group by VariableLocator // map tensors to this group by VariableLocator
...@@ -418,6 +429,53 @@ void EagerReducer::InitializeDenseGroups( ...@@ -418,6 +429,53 @@ void EagerReducer::InitializeDenseGroups(
p_group->all_length_ = all_length; p_group->all_length_ = all_length;
} }
void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
std::queue<egr::GradNodeBase *> queue;
std::set<egr::GradNodeBase *> visited;
for (const auto &output : outputs) {
auto *auto_grad_meta =
static_cast<egr::AutogradMeta *>(output.get_autograd_meta());
if (!auto_grad_meta) continue;
auto shared_grad_node = auto_grad_meta->GetMutableGradNode();
if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
auto_grad_meta->StopGradient()) {
continue;
}
egr::GradNodeBase *grad_node = shared_grad_node.get();
queue.emplace(grad_node);
}
while (!queue.empty()) {
egr::GradNodeBase *node = queue.front();
queue.pop();
const std::vector<std::vector<egr::Edge>> &edges = node->GetEdges();
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const egr::Edge &edge = edges[i][j];
auto next_node_shared = edge.GetMutableGradNode();
if (!next_node_shared || !next_node_shared.get()) {
continue;
}
auto *next_node = next_node_shared.get();
const bool was_inserted = visited.insert(next_node).second;
if (was_inserted) {
queue.emplace(next_node);
}
}
}
}
for (const auto &it : gradnode_index_map_) {
if (visited.count(it.first) == 0) {
unused_vars_.push_back(it.second);
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Tensor " << tensors_[it.second].name() << " at index "
<< it.second << " is marked as unused.";
}
}
}
void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) { void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
VLOG(3) << "after forward, then reset count for backward."; VLOG(3) << "after forward, then reset count for backward.";
grad_need_hooks_ = true; grad_need_hooks_ = true;
...@@ -429,6 +487,51 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) { ...@@ -429,6 +487,51 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
// reinitialize vars_marked_ready_ for next iteration // reinitialize vars_marked_ready_ for next iteration
vars_marked_ready_.clear(); vars_marked_ready_.clear();
vars_marked_ready_.resize(tensors_.size(), false); vars_marked_ready_.resize(tensors_.size(), false);
PADDLE_ENFORCE_EQ(
groups_need_finalize_, false,
platform::errors::PreconditionNotMet(
"A serious error has occurred here. Please "
"set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have "
"set, There may be several reasons for this error: "
"1) Please note that all forward outputs derived from the module "
"parameters must participate in the calculation of losses and "
"subsequent gradient calculations. If not, the wrapper will hang, "
"waiting for autograd to generate gradients for these parameters. "
"you can use detach or stop_gradient to make the unused parameters "
"detached from the autograd graph. "
"2) Used multiple forwards and one backward. You may be able to wrap "
"multiple forwards in a model."));
// The first var to trigger the unused parameter
has_marked_unused_vars_ = false;
if (find_unused_vars_once_ || find_unused_vars_each_step_) {
unused_vars_.clear();
TraverseBackwardGraph(outputs);
// only check once in first step
find_unused_vars_once_ = false;
}
if (find_unused_vars_each_step_ && unused_vars_.empty()) {
LOG_FIRST_N(WARNING, 1)
<< "All parameters are involved in the backward pass. "
"It is recommended to set find_unused_parameters to False "
"to improve performance. However, if unused parameters "
"appear in subsequent iterative training, then an error "
"will occur. Please make it clear that in the subsequent "
"training, there will be no parameters that are not used "
"in the backward pass, and then set find_unused_parameters";
}
if (unused_vars_.size() == tensors_.size()) {
LOG_FIRST_N(WARNING, 1)
<< "There is no parameter in the device involved "
"in the backward calculation. If there are "
"parameters on other devices involved in the "
"backward, then a serious error will occur here.";
}
} }
void EagerReducer::AddDistHook(size_t var_index) { void EagerReducer::AddDistHook(size_t var_index) {
...@@ -446,36 +549,104 @@ void EagerReducer::AddDistHook(size_t var_index) { ...@@ -446,36 +549,104 @@ void EagerReducer::AddDistHook(size_t var_index) {
auto &tensor = tensors_[var_index]; auto &tensor = tensors_[var_index];
const auto &grad_node = GetGradNodeFromTensor(&tensor); const auto &grad_node = GetGradNodeFromTensor(&tensor);
VLOG(3) << "Var[" << var_index << "] [" << (*grad_node).name() VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
<< "] arrived and triggered disthook"; << "@Grad] arrived and triggered disthook";
local_used_vars_[var_index] = 1; local_used_vars_[var_index] = 1;
if (!has_marked_unused_vars_) {
has_marked_unused_vars_ = true;
for (const auto unused_index : unused_vars_) {
MarkVarReady(unused_index, false);
}
}
MarkVarReady(var_index, true); MarkVarReady(var_index, true);
} }
void EagerReducer::MarkVarReady(const size_t var_index, void EagerReducer::MarkVarReady(const size_t var_index,
const bool is_used_var) { const bool is_used_var) {
VLOG(3) << "Tensor[" << var_index << "][" << tensors_[var_index].name()
<< "] is marked ready.";
// error happened, if the var is ready before.
if (vars_marked_ready_[var_index]) {
auto error_info = string::Sprintf(
"Error happened, when parameter[%d][%s] has been ready before. "
"Please set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have set, "
"there may be several reasons for this error: "
"1) In multiple reentrant backward phase, some parameters are reused."
"2) Using model parameters outside of forward function. Please "
"make sure that model parameters are not shared in concurrent "
"forward-backward passes.",
var_index, tensors_[var_index].name());
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, false,
platform::errors::PreconditionNotMet(error_info));
error_info +=
"3) Unused parameters retrieval is incorrect. "
"The return value of forward will be used to retrieve"
" the unused parameters of the entire model. These "
"gradients of unused parameters will not be synchronized "
"between multiple cards. However, if the unused "
"parameters participate in the backward calculation "
"again at a later time (e.g. after the forward function, "
"the loss calculation uses the unused "
"paramters of the forward and trigger backward), "
"its gradient will be wrong.";
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, true,
platform::errors::PreconditionNotMet(error_info));
} else {
vars_marked_ready_[var_index] = true;
}
groups_need_finalize_ = true;
const auto &var_locator = variable_locators_[var_index]; const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index; const auto group_index = var_locator.group_index;
const auto inside_group_index = var_locator.inside_group_index; const auto inside_group_index = var_locator.inside_group_index;
auto &group = groups_[group_index]; auto &group = groups_[group_index];
auto &group_tensor = group.dense_tensors_[inside_group_index]; auto &group_tensor = group.dense_tensors_[inside_group_index];
auto *autograd_meta = tensors_[var_index].get_autograd_meta(); const auto length = group.length_[inside_group_index];
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
if (is_used_var) {
group_tensor auto *autograd_meta = tensors_[var_index].get_autograd_meta();
.ShareDataWith( auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl()))) group_tensor
.Resize({grad_tensor.numel()}); .ShareDataWith(
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
vars_marked_ready_[var_index] = true; .Resize({grad_tensor.numel()});
} else {
// TODO(shenliang03): maybe save the memory by avoiding tensor construction
if (!group_tensor.initialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(inner_place_, group.dtype_);
}
if (HasGrad(var_index)) {
VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
group_tensor
.ShareDataWith(*(
std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor->impl())))
.Resize({length});
} else {
VLOG(3) << "Tensor[" << tensors_[var_index].name()
<< "] doesn't have grad";
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
group_tensor.Resize({static_cast<int64_t>(length)});
phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
}
}
if (--group.pending_ == 0) { if (--group.pending_ == 0) {
// can start allreduce // can start allreduce
MarkGroupReady(group_index); MarkGroupReady(group_index);
} }
if (next_group_ == groups_.size()) {
FinalizeBackward();
}
} }
void EagerReducer::MarkGroupReady(size_t group_index) { void EagerReducer::MarkGroupReady(size_t group_index) {
...@@ -501,6 +672,92 @@ void EagerReducer::MarkGroupReady(size_t group_index) { ...@@ -501,6 +672,92 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
} }
} }
bool EagerReducer::HasGrad(size_t var_index) {
auto grad = egr::EagerUtils::mutable_grad(tensors_[var_index]);
if (grad && grad->is_initialized()) {
return true;
} else {
return false;
}
}
void EagerReducer::ProcessUnusedDenseVars() {
// The calculation stream must be used here to
// avoid conflicts with communication.
VLOG(3) << "Local used vars : "
<< string::join_strings(local_used_vars_, ',');
const auto *dev_ctx =
platform::DeviceContextPool::Instance().Get(inner_place_);
auto *global_used_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(global_used_vars_.impl())
.get();
framework::TensorFromVector<int32_t>(local_used_vars_, *dev_ctx,
global_used_tensor);
distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {global_used_vars_};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
framework::TensorToVector<int>(*global_used_tensor, *dev_ctx,
&local_used_vars_);
dev_ctx->Wait();
// sync compute stream to get global used var message,
// but maybe affect speed performance
VLOG(3) << "Global used vars : "
<< string::join_strings(local_used_vars_, ',');
for (const auto var_index : unused_vars_) {
const bool global_unused = (local_used_vars_[var_index] == 0);
// global used but local unused, set grad
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Var [" << var_index << "] [" << tensors_[var_index].name()
<< "] global_unused: " << global_unused
<< " has grad: " << HasGrad(var_index);
if (!global_unused) {
VLOG(3) << "Set Tensor[" << var_index << "]'s Grad for [Rank "
<< process_group_->GetRank() << "]";
const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index;
const auto &group = groups_[group_index];
const auto inside_group_index = var_locator.inside_group_index;
auto &src_tensor = group.dense_tensors_[inside_group_index];
Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));
auto dest_var_base = tensors_[var_index];
auto grad_tensor = egr::EagerUtils::mutable_grad(dest_var_base);
grad_tensor->copy_(grad_value, inner_place_, true);
grad_tensor->reshape(dest_var_base.shape());
}
}
}
void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
for (auto &group : groups_) {
group.task->Synchronize();
}
for (auto &group : groups_) {
group.SplitTensors(inner_place_);
}
if (find_unused_vars_each_step_) {
ProcessUnusedDenseVars();
local_used_vars_.clear();
local_used_vars_.resize(tensors_.size(), 0);
VLOG(3) << "ProcessUnusedDenseVars is finished.";
}
VLOG(3) << "In the batch, Reducer is finished.";
}
void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
const int curr_group_index) { const int curr_group_index) {
// The overall timeline: concat > div_nranks > allreduce > split // The overall timeline: concat > div_nranks > allreduce > split
...@@ -513,24 +770,14 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, ...@@ -513,24 +770,14 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
group->ConcatTensors(inner_place_); group->ConcatTensors(inner_place_);
// div nranks // div nranks
double scaling = 1.0 / nranks_; paddle::experimental::scale_(group->dense_contents_, 1.0 / nranks_, 0.0,
paddle::experimental::scale_(group->dense_contents_, scaling, 0.0, false); false);
// all_reduce // all_reduce
std::vector<Tensor> reduce_tensors = {group->dense_contents_}; std::vector<Tensor> reduce_tensors = {group->dense_contents_};
tasks_.push_back(process_group_->AllReduce(reduce_tensors, opts)); group->task = process_group_->AllReduce(reduce_tensors, opts);
if (tasks_.size() == groups_.size()) { // split in FinalizeBackward()
for (size_t index = 0; index < tasks_.size(); index++) {
auto &task = tasks_.back();
task->Synchronize();
tasks_.pop_back();
}
for (size_t index = 0; index < groups_.size(); index++) {
auto &group = groups_[index];
group.SplitTensors(inner_place_);
}
}
} }
std::ostream &operator<<(std::ostream &out, const EagerGroup &group) { std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
......
...@@ -28,6 +28,8 @@ ...@@ -28,6 +28,8 @@
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/ext_compat_utils.h" #include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -35,6 +37,7 @@ using Tensor = paddle::experimental::Tensor; ...@@ -35,6 +37,7 @@ using Tensor = paddle::experimental::Tensor;
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>; using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
using ScalarArray = using ScalarArray =
paddle::experimental::ScalarArrayBase<paddle::experimental::Tensor>; paddle::experimental::ScalarArrayBase<paddle::experimental::Tensor>;
using Backend = paddle::experimental::Backend;
std::vector<std::vector<size_t>> Eager_AssignGroupBySize( std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient, const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient,
...@@ -61,6 +64,9 @@ class EagerGroup { ...@@ -61,6 +64,9 @@ class EagerGroup {
// external message of group // external message of group
phi::DataType dtype_; phi::DataType dtype_;
// help to sync
std::shared_ptr<ProcessGroup::Task> task;
// context is used to select the stream for concat // context is used to select the stream for concat
void ConcatTensors(const platform::Place &); void ConcatTensors(const platform::Place &);
...@@ -98,6 +104,10 @@ class EagerReducer { ...@@ -98,6 +104,10 @@ class EagerReducer {
void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(const size_t group_index); void MarkGroupReady(const size_t group_index);
void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index); void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
void FinalizeBackward();
void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
void ProcessUnusedDenseVars();
bool HasGrad(size_t var_index);
private: private:
std::vector<Tensor> tensors_; std::vector<Tensor> tensors_;
...@@ -105,7 +115,6 @@ class EagerReducer { ...@@ -105,7 +115,6 @@ class EagerReducer {
std::vector<bool> is_sparse_gradient_; std::vector<bool> is_sparse_gradient_;
std::shared_ptr<distributed::ProcessGroup> process_group_; std::shared_ptr<distributed::ProcessGroup> process_group_;
std::vector<size_t> group_size_limits_; std::vector<size_t> group_size_limits_;
bool find_unused_vars_each_step_;
std::vector<EagerGroup> groups_; std::vector<EagerGroup> groups_;
std::vector<TensorLocator> variable_locators_; std::vector<TensorLocator> variable_locators_;
...@@ -113,12 +122,20 @@ class EagerReducer { ...@@ -113,12 +122,20 @@ class EagerReducer {
platform::Place inner_place_; platform::Place inner_place_;
size_t next_group_ = 0; size_t next_group_ = 0;
int64_t nranks_ = -1; int64_t nranks_ = -1;
std::vector<std::shared_ptr<paddle::distributed::ProcessGroup::Task>> tasks_;
bool grad_need_hooks_{false}; bool grad_need_hooks_{false};
std::vector<bool> vars_marked_ready_; std::vector<bool> vars_marked_ready_;
std::vector<int> local_used_vars_; std::vector<int32_t> local_used_vars_;
// Following variables are to help unused vars
std::vector<size_t> unused_vars_;
std::map<egr::GradNodeBase *, size_t> gradnode_index_map_;
bool has_marked_unused_vars_{false};
bool find_unused_vars_each_step_{false};
bool find_unused_vars_once_{true};
bool groups_need_finalize_{false};
Tensor global_used_vars_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -4,7 +4,7 @@ if(WITH_PYTHON) ...@@ -4,7 +4,7 @@ if(WITH_PYTHON)
endif() endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto) proto_library(interceptor_message_proto SRCS interceptor_message.proto)
if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) if(WITH_DISTRIBUTE AND WITH_PSCORE)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog) set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else() else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
......
...@@ -67,8 +67,7 @@ bool MessageBus::IsInit() const { return is_init_; } ...@@ -67,8 +67,7 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() { MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource."; VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
server_.Join(); server_.Join();
#endif #endif
...@@ -87,8 +86,7 @@ bool MessageBus::Send(int64_t dst_rank, ...@@ -87,8 +86,7 @@ bool MessageBus::Send(int64_t dst_rank,
IsInit(), true, IsInit(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized.")); "Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times int retry_time = 0; // message bus will retry sending for 10 times
while (retry_time < 10) { while (retry_time < 10) {
++retry_time; ++retry_time;
...@@ -173,8 +171,7 @@ void MessageBus::ListenPort() { ...@@ -173,8 +171,7 @@ void MessageBus::ListenPort() {
LOG(INFO) << "No need listen to port since training on single card."; LOG(INFO) << "No need listen to port since training on single card.";
return; return;
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message // function keep listen the port and handle the message
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0, server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
...@@ -203,8 +200,7 @@ void MessageBus::ListenPort() { ...@@ -203,8 +200,7 @@ void MessageBus::ListenPort() {
#endif #endif
} }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
bool MessageBus::SendInterRank(int64_t dst_rank, bool MessageBus::SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
const auto& dst_addr = GetAddr(dst_rank); const auto& dst_addr = GetAddr(dst_rank);
......
...@@ -20,8 +20,7 @@ ...@@ -20,8 +20,7 @@
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h" #include "brpc/channel.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
...@@ -64,8 +63,7 @@ class MessageBus final { ...@@ -64,8 +63,7 @@ class MessageBus final {
const std::string& GetAddr(int64_t rank) const; const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src) // send the message inter rank (dst is different rank with src)
bool SendInterRank(int64_t dst_rank, bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message); const InterceptorMessage& interceptor_message);
...@@ -81,8 +79,7 @@ class MessageBus final { ...@@ -81,8 +79,7 @@ class MessageBus final {
// the ip needs to be listened // the ip needs to be listened
std::string addr_; std::string addr_;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
MessageServiceImpl message_service_; MessageServiceImpl message_service_;
// brpc server // brpc server
brpc::Server server_; brpc::Server server_;
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
!defined(PADDLE_WITH_ASCEND_CL)
#pragma once #pragma once
#include "brpc/server.h" #include "brpc/server.h"
......
...@@ -115,6 +115,7 @@ message TableParameter { ...@@ -115,6 +115,7 @@ message TableParameter {
optional CommonAccessorParameter common = 6; optional CommonAccessorParameter common = 6;
optional TableType type = 7; optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ]; optional bool compress_in_save = 8 [ default = false ];
optional GraphParameter graph_parameter = 9;
} }
message TableAccessorParameter { message TableAccessorParameter {
...@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule ...@@ -211,3 +212,25 @@ message SparseAdamSGDParameter { // SparseAdamSGDRule
optional double ada_epsilon = 5 [ default = 1e-08 ]; optional double ada_epsilon = 5 [ default = 1e-08 ];
repeated float weight_bounds = 6; repeated float weight_bounds = 6;
} }
message GraphParameter {
optional int32 task_pool_size = 1 [ default = 24 ];
optional bool gpups_mode = 2 [ default = false ];
optional string gpups_graph_sample_class = 3
[ default = "CompleteGraphSampler" ];
optional string gpups_graph_sample_args = 4 [ default = "" ];
optional bool use_cache = 5 [ default = true ];
optional float cache_ratio = 6 [ default = 0.3 ];
optional int32 cache_ttl = 7 [ default = 5 ];
optional GraphFeature graph_feature = 8;
optional string table_name = 9 [ default = "" ];
optional string table_type = 10 [ default = "" ];
optional int32 gpups_mode_shard_num = 11 [ default = 127 ];
optional int32 gpu_num = 12 [ default = 1 ];
}
message GraphFeature {
repeated string name = 1;
repeated string dtype = 2;
repeated int32 shape = 3;
}
\ No newline at end of file
...@@ -44,7 +44,7 @@ void GraphPsService_Stub::service( ...@@ -44,7 +44,7 @@ void GraphPsService_Stub::service(
} }
} }
int GraphBrpcClient::get_server_index_by_id(uint64_t id) { int GraphBrpcClient::get_server_index_by_id(int64_t id) {
int shard_num = get_shard_num(); int shard_num = get_shard_num();
int shard_per_server = shard_num % server_size == 0 int shard_per_server = shard_num % server_size == 0
? shard_num / server_size ? shard_num / server_size
...@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) { ...@@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
} }
std::future<int32_t> GraphBrpcClient::get_node_feat( std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids, const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) { std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server; std::vector<int> request2server;
...@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -66,7 +66,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]); int server_index = get_server_index_by_id(node_ids[query_idx]);
...@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat( ...@@ -129,7 +129,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
std::string joint_feature_name = std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t'); paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx) closure->request(request_idx)
...@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) { ...@@ -179,9 +179,9 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return fut; return fut;
} }
std::future<int32_t> GraphBrpcClient::add_graph_node( std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list, uint32_t table_id, std::vector<int64_t> &node_id_list,
std::vector<bool> &is_weighted_list) { std::vector<bool> &is_weighted_list) {
std::vector<std::vector<uint64_t>> request_bucket; std::vector<std::vector<int64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket; std::vector<std::vector<bool>> is_weighted_bucket;
bool add_weight = is_weighted_list.size() > 0; bool add_weight = is_weighted_list.size() > 0;
std::vector<int> server_index_arr; std::vector<int> server_index_arr;
...@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -191,7 +191,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
if (index_mapping[server_index] == -1) { if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size(); index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index); server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>()); request_bucket.push_back(std::vector<int64_t>());
if (add_weight) is_weighted_bucket.push_back(std::vector<bool>()); if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
} }
request_bucket[index_mapping[server_index]].push_back( request_bucket[index_mapping[server_index]].push_back(
...@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -229,7 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
size_t node_num = request_bucket[request_idx].size(); size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
if (add_weight) { if (add_weight) {
bool weighted[is_weighted_bucket[request_idx].size() + 1]; bool weighted[is_weighted_bucket[request_idx].size() + 1];
for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++) for (size_t j = 0; j < is_weighted_bucket[request_idx].size(); j++)
...@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node( ...@@ -248,8 +248,8 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return fut; return fut;
} }
std::future<int32_t> GraphBrpcClient::remove_graph_node( std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<uint64_t> &node_id_list) { uint32_t table_id, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<uint64_t>> request_bucket; std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr; std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1); std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) { for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
...@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -257,7 +257,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
if (index_mapping[server_index] == -1) { if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size(); index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index); server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<uint64_t>()); request_bucket.push_back(std::vector<int64_t>());
} }
request_bucket[index_mapping[server_index]].push_back( request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]); node_id_list[query_idx]);
...@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -291,7 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(), ->add_params((char *)request_bucket[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(get_cmd_channel(server_index)); // PsService_Stub rpc_stub(get_cmd_channel(server_index));
GraphPsService_Stub rpc_stub = GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index)); getServiceStub(get_cmd_channel(server_index));
...@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node( ...@@ -303,9 +303,9 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
} }
// char* &buffer,int &actual_size // char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<uint64_t, float>>> &res, // std::vector<std::vector<std::pair<int64_t, float>>> &res,
std::vector<std::vector<uint64_t>> &res, std::vector<std::vector<int64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight, std::vector<std::vector<float>> &res_weight, bool need_weight,
int server_index) { int server_index) {
if (server_index != -1) { if (server_index != -1) {
...@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -337,7 +337,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[node_idx].emplace_back( res[node_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start)); *(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size; start += GraphNode::id_size;
if (need_weight) { if (need_weight) {
res_weight[node_idx].emplace_back( res_weight[node_idx].emplace_back(
...@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -358,7 +358,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->set_table_id(table_id); closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id); closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)node_ids.data(), closure->request(0)->add_params((char *)node_ids.data(),
sizeof(uint64_t) * node_ids.size()); sizeof(int64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int)); closure->request(0)->add_params((char *)&sample_size, sizeof(int));
closure->request(0)->add_params((char *)&need_weight, sizeof(bool)); closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
; ;
...@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -380,14 +380,14 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
server2request[server_index] = request2server.size(); server2request[server_index] = request2server.size();
request2server.push_back(server_index); request2server.push_back(server_index);
} }
// res.push_back(std::vector<std::pair<uint64_t, float>>()); // res.push_back(std::vector<std::pair<int64_t, float>>());
res.push_back({}); res.push_back({});
if (need_weight) { if (need_weight) {
res_weight.push_back({}); res_weight.push_back({});
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) { for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]); int server_index = get_server_index_by_id(node_ids[query_idx]);
...@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -428,7 +428,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
int start = 0; int start = 0;
while (start < actual_size) { while (start < actual_size) {
res[query_idx].emplace_back( res[query_idx].emplace_back(
*(uint64_t *)(node_buffer + offset + start)); *(int64_t *)(node_buffer + offset + start));
start += GraphNode::id_size; start += GraphNode::id_size;
if (need_weight) { if (need_weight) {
res_weight[query_idx].emplace_back( res_weight[query_idx].emplace_back(
...@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -459,7 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
...@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors( ...@@ -476,7 +476,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
} }
std::future<int32_t> GraphBrpcClient::random_sample_nodes( std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size, uint32_t table_id, int server_index, int sample_size,
std::vector<uint64_t> &ids) { std::vector<int64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0; int ret = 0;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = (DownpourBrpcClosure *)done;
...@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes( ...@@ -490,7 +490,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size); auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0; int index = 0;
while (index < bytes_size) { while (index < bytes_size) {
ids.push_back(*(uint64_t *)(buffer + index)); ids.push_back(*(int64_t *)(buffer + index));
index += GraphNode::id_size; index += GraphNode::id_size;
} }
delete[] buffer; delete[] buffer;
...@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list( ...@@ -633,7 +633,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
} }
std::future<int32_t> GraphBrpcClient::set_node_feat( std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<uint64_t> &node_ids, const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) { const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server; std::vector<int> request2server;
...@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -646,7 +646,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
} }
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets( std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num); request_call_num);
...@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat( ...@@ -696,7 +696,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
std::string joint_feature_name = std::string joint_feature_name =
paddle::string::join_strings(feature_names, '\t'); paddle::string::join_strings(feature_names, '\t');
closure->request(request_idx) closure->request(request_idx)
......
...@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -63,8 +63,8 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {} virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them // given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors( virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size, uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
std::vector<std::vector<uint64_t>>& res, std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight, std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1); int server_index = -1);
...@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -75,20 +75,20 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id, virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index, int server_index,
int sample_size, int sample_size,
std::vector<uint64_t>& ids); std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat( virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids, const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res); std::vector<std::vector<std::string>>& res);
virtual std::future<int32_t> set_node_feat( virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<uint64_t>& node_ids, const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names, const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features); const std::vector<std::vector<std::string>>& features);
virtual std::future<int32_t> clear_nodes(uint32_t table_id); virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node( virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list, uint32_t table_id, std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list); std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id, virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit, size_t size_limit,
...@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient { ...@@ -96,11 +96,11 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id, virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path); std::string path);
virtual std::future<int32_t> remove_graph_node( virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list); uint32_t table_id, std::vector<int64_t>& node_id_list);
virtual int32_t initialize(); virtual int32_t initialize();
int get_shard_num() { return shard_num; } int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; } void set_shard_num(int shard_num) { this->shard_num = shard_num; }
int get_server_index_by_id(uint64_t id); int get_server_index_by_id(int64_t id);
void set_local_channel(int index) { void set_local_channel(int index) {
this->local_channel = get_cmd_channel(index); this->local_channel = get_cmd_channel(index);
} }
......
...@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table, ...@@ -140,9 +140,9 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list; std::vector<bool> is_weighted_list;
if (request.params_size() == 2) { if (request.params_size() == 2) {
size_t weight_list_size = request.params(1).size() / sizeof(bool); size_t weight_list_size = request.params(1).size() / sizeof(bool);
...@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ...@@ -165,9 +165,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"graph_get_node_feat request requires at least 1 argument"); "graph_get_node_feat request requires at least 1 argument");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(node_ids); ((GraphTable *)table)->remove_graph_node(node_ids);
return 0; return 0;
...@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -386,9 +386,9 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments"); "graph_random_sample_neighbors request requires at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(2).c_str()); bool need_weight = *(bool *)(request.params(2).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num); std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
...@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -407,7 +407,7 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
int32_t GraphBrpcService::graph_random_sample_nodes( int32_t GraphBrpcService::graph_random_sample_nodes(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
size_t size = *(uint64_t *)(request.params(0).c_str()); size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) == if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
...@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -430,9 +430,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 2 arguments"); "graph_get_node_feat request requires at least 2 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names = std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t"); paddle::string::split_string<std::string>(request.params(1), "\t");
...@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -464,16 +464,16 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
"at least 3 arguments"); "at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t), size_t node_num = request.params(0).size() / sizeof(int64_t),
size_of_size_t = sizeof(size_t); size_of_size_t = sizeof(size_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str()); int sample_size = *(int64_t *)(request.params(1).c_str());
bool need_weight = *(uint64_t *)(request.params(2).c_str()); bool need_weight = *(int64_t *)(request.params(2).c_str());
// std::vector<uint64_t> res = ((GraphTable // std::vector<int64_t> res = ((GraphTable
// *)table).filter_out_non_exist_nodes(node_data, sample_size); // *)table).filter_out_non_exist_nodes(node_data, sample_size);
std::vector<int> request2server; std::vector<int> request2server;
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
std::vector<uint64_t> local_id; std::vector<int64_t> local_id;
std::vector<int> local_query_idx; std::vector<int> local_query_idx;
size_t rank = get_rank(); size_t rank = get_rank();
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
...@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -496,7 +496,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::shared_ptr<char>> local_buffers; std::vector<std::shared_ptr<char>> local_buffers;
std::vector<int> local_actual_sizes; std::vector<int> local_actual_sizes;
std::vector<size_t> seq; std::vector<size_t> seq;
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_num; ++query_idx) { for (int query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index =
...@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -583,7 +583,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num); sizeof(int64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params((char *)&sample_size, sizeof(int));
closure->request(request_idx) closure->request(request_idx)
...@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, ...@@ -618,9 +618,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments"); "graph_set_node_feat request requires at least 3 arguments");
return 0; return 0;
} }
size_t node_num = request.params(0).size() / sizeof(uint64_t); size_t node_num = request.params(0).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str()); int64_t *node_data = (int64_t *)(request.params(0).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<int64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names = std::vector<std::string> feature_names =
paddle::string::split_string<std::string>(request.params(1), "\t"); paddle::string::split_string<std::string>(request.params(1), "\t");
......
...@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name, ...@@ -44,9 +44,9 @@ void GraphPyService::add_table_feat_conf(std::string table_name,
} }
} }
void add_graph_node(std::vector<uint64_t> node_ids, void add_graph_node(std::vector<int64_t> node_ids,
std::vector<bool> weight_list) {} std::vector<bool> weight_list) {}
void remove_graph_node(std::vector<uint64_t> node_ids) {} void remove_graph_node(std::vector<int64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str, int shard_num, void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> node_types, std::vector<std::string> node_types,
std::vector<std::string> edge_types) { std::vector<std::string> edge_types) {
...@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) { ...@@ -260,7 +260,7 @@ void GraphPyClient::clear_nodes(std::string name) {
} }
void GraphPyClient::add_graph_node(std::string name, void GraphPyClient::add_graph_node(std::string name,
std::vector<uint64_t>& node_ids, std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list) { std::vector<bool>& weight_list) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
...@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name, ...@@ -271,7 +271,7 @@ void GraphPyClient::add_graph_node(std::string name,
} }
void GraphPyClient::remove_graph_node(std::string name, void GraphPyClient::remove_graph_node(std::string name,
std::vector<uint64_t>& node_ids) { std::vector<int64_t>& node_ids) {
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = get_ps_client()->remove_graph_node(table_id, node_ids); auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
...@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) { ...@@ -290,13 +290,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
} }
} }
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name, GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids, std::vector<int64_t> node_ids,
int sample_size, bool return_weight, int sample_size, bool return_weight,
bool return_edges) { bool return_edges) {
// std::vector<std::vector<std::pair<uint64_t, float>>> v; std::vector<std::vector<int64_t>> v;
std::vector<std::vector<uint64_t>> v;
std::vector<std::vector<float>> v1; std::vector<std::vector<float>> v1;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
...@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name, ...@@ -309,7 +308,7 @@ GraphPyClient::batch_sample_neighbors(std::string name,
// res.first[1]: slice index // res.first[1]: slice index
// res.first[2]: src nodes // res.first[2]: src nodes
// res.second: edges weight // res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res; std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res.first.push_back({}); res.first.push_back({});
res.first.push_back({}); res.first.push_back({});
if (return_edges) res.first.push_back({}); if (return_edges) res.first.push_back({});
...@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name, ...@@ -342,10 +341,10 @@ void GraphPyClient::use_neighbors_sample_cache(std::string name,
status.wait(); status.wait();
} }
} }
std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name, std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index, int server_index,
int sample_size) { int sample_size) {
std::vector<uint64_t> v; std::vector<int64_t> v;
if (this->table_id_map.count(name)) { if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name]; uint32_t table_id = this->table_id_map[name];
auto status = auto status =
...@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name, ...@@ -357,7 +356,7 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
// (name, dtype, ndarray) // (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat( std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names) { std::vector<std::string> feature_names) {
std::vector<std::vector<std::string>> v( std::vector<std::vector<std::string>> v(
feature_names.size(), std::vector<std::string>(node_ids.size())); feature_names.size(), std::vector<std::string>(node_ids.size()));
...@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat( ...@@ -371,7 +370,7 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
} }
void GraphPyClient::set_node_feat( void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names, std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) { const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) { if (this->table_id_map.count(node_type)) {
......
...@@ -70,18 +70,34 @@ class GraphPyService { ...@@ -70,18 +70,34 @@ class GraphPyService {
::paddle::distributed::TableAccessorParameter* accessor_proto = ::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor(); sparse_table_proto->mutable_accessor();
::paddle::distributed::CommonAccessorParameter* common_proto = // ::paddle::distributed::CommonAccessorParameter* common_proto =
sparse_table_proto->mutable_common(); // sparse_table_proto->mutable_common();
::paddle::distributed::GraphParameter* graph_proto =
sparse_table_proto->mutable_graph_parameter();
::paddle::distributed::GraphFeature* graph_feature =
graph_proto->mutable_graph_feature();
graph_proto->set_task_pool_size(24);
graph_proto->set_table_name(table_name);
graph_proto->set_table_type(table_type);
graph_proto->set_use_cache(false);
// Set GraphTable Parameter // Set GraphTable Parameter
common_proto->set_table_name(table_name); // common_proto->set_table_name(table_name);
common_proto->set_name(table_type); // common_proto->set_name(table_type);
// for (size_t i = 0; i < feat_name.size(); i++) {
// common_proto->add_params(feat_dtype[i]);
// common_proto->add_dims(feat_shape[i]);
// common_proto->add_attributes(feat_name[i]);
// }
for (size_t i = 0; i < feat_name.size(); i++) { for (size_t i = 0; i < feat_name.size(); i++) {
common_proto->add_params(feat_dtype[i]); graph_feature->add_dtype(feat_dtype[i]);
common_proto->add_dims(feat_shape[i]); graph_feature->add_shape(feat_shape[i]);
common_proto->add_attributes(feat_name[i]); graph_feature->add_name(feat_name[i]);
} }
accessor_proto->set_accessor_class("CommMergeAccessor"); accessor_proto->set_accessor_class("CommMergeAccessor");
} }
...@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService { ...@@ -143,24 +159,24 @@ class GraphPyClient : public GraphPyService {
void load_edge_file(std::string name, std::string filepath, bool reverse); void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath); void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name); void clear_nodes(std::string name);
void add_graph_node(std::string name, std::vector<uint64_t>& node_ids, void add_graph_node(std::string name, std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list); std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<uint64_t>& node_ids); void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
int get_client_id() { return client_id; } int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; } void set_client_id(int client_id) { this->client_id = client_id; }
void start_client(); void start_client();
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids, batch_sample_neighbors(std::string name, std::vector<int64_t> node_ids,
int sample_size, bool return_weight, int sample_size, bool return_weight,
bool return_edges); bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index, std::vector<int64_t> random_sample_nodes(std::string name, int server_index,
int sample_size); int sample_size);
std::vector<std::vector<std::string>> get_node_feat( std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids, std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names); std::vector<std::string> feature_names);
void use_neighbors_sample_cache(std::string name, size_t total_size_limit, void use_neighbors_sample_cache(std::string name, size_t total_size_limit,
size_t ttl); size_t ttl);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids, void set_node_feat(std::string node_type, std::vector<int64_t> node_ids,
std::vector<std::string> feature_names, std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features); const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index, std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
......
...@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro ...@@ -53,7 +53,6 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table) cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp) target_link_libraries(table -fopenmp)
...@@ -38,10 +38,14 @@ ...@@ -38,10 +38,14 @@
#include <vector> #include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h" #include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class GraphShard { class GraphShard {
...@@ -51,37 +55,37 @@ class GraphShard { ...@@ -51,37 +55,37 @@ class GraphShard {
~GraphShard(); ~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; } std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step); std::vector<Node *> get_batch(int start, int end, int step);
std::vector<uint64_t> get_ids_by_range(int start, int end) { std::vector<int64_t> get_ids_by_range(int start, int end) {
std::vector<uint64_t> res; std::vector<int64_t> res;
for (int i = start; i < end && i < (int)bucket.size(); i++) { for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id()); res.push_back(bucket[i]->get_id());
} }
return res; return res;
} }
GraphNode *add_graph_node(uint64_t id); GraphNode *add_graph_node(int64_t id);
GraphNode *add_graph_node(Node *node); GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id); FeatureNode *add_feature_node(int64_t id);
Node *find_node(uint64_t id); Node *find_node(int64_t id);
void delete_node(uint64_t id); void delete_node(int64_t id);
void clear(); void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight); void add_neighbor(int64_t id, int64_t dst_id, float weight);
std::unordered_map<uint64_t, int> &get_node_location() { std::unordered_map<int64_t, int> &get_node_location() {
return node_location; return node_location;
} }
private: private:
std::unordered_map<uint64_t, int> node_location; std::unordered_map<int64_t, int> node_location;
std::vector<Node *> bucket; std::vector<Node *> bucket;
}; };
enum LRUResponse { ok = 0, blocked = 1, err = 2 }; enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey { struct SampleKey {
uint64_t node_key; int64_t node_key;
size_t sample_size; size_t sample_size;
bool is_weighted; bool is_weighted;
SampleKey(uint64_t _node_key, size_t _sample_size, bool _is_weighted) SampleKey(int64_t _node_key, size_t _sample_size, bool _is_weighted)
: node_key(_node_key), : node_key(_node_key),
sample_size(_sample_size), sample_size(_sample_size),
is_weighted(_is_weighted) {} is_weighted(_is_weighted) {}
...@@ -300,7 +304,7 @@ class ScaledLRU { ...@@ -300,7 +304,7 @@ class ScaledLRU {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count; node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
} }
if (node_size <= size_t(1.1 * size_limit) + 1) return 0; if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) { if (pthread_rwlock_wrlock(&rwlock) == 0) {
// VLOG(0)<"in shrink\n"; // VLOG(0)<"in shrink\n";
global_count = 0; global_count = 0;
...@@ -308,9 +312,9 @@ class ScaledLRU { ...@@ -308,9 +312,9 @@ class ScaledLRU {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count; global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
} }
// VLOG(0)<<"global_count "<<global_count<<"\n"; // VLOG(0)<<"global_count "<<global_count<<"\n";
if (global_count > size_limit) { if ((size_t)global_count > size_limit) {
size_t remove = global_count - size_limit; size_t remove = global_count - size_limit;
for (int i = 0; i < lru_pool.size(); i++) { for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0; lru_pool[i].total_diff = 0;
lru_pool[i].remove_count += lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) / 1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
...@@ -352,9 +356,69 @@ class ScaledLRU { ...@@ -352,9 +356,69 @@ class ScaledLRU {
friend class RandomSampleLRU<K, V>; friend class RandomSampleLRU<K, V>;
}; };
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
class GraphTable : public SparseTable { class GraphTable : public SparseTable {
public: public:
GraphTable() { use_cache = false; } GraphTable() {
use_cache = false;
shard_num = 0;
#ifdef PADDLE_WITH_HETERPS
gpups_mode = false;
#endif
rw_lock.reset(new pthread_rwlock_t());
}
virtual ~GraphTable(); virtual ~GraphTable();
virtual int32_t pull_graph_list(int start, int size, virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer, std::unique_ptr<char[]> &buffer,
...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable { ...@@ -362,7 +426,7 @@ class GraphTable : public SparseTable {
int step); int step);
virtual int32_t random_sample_neighbors( virtual int32_t random_sample_neighbors(
uint64_t *node_ids, int sample_size, int64_t *node_ids, int sample_size,
std::vector<std::shared_ptr<char>> &buffers, std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes, bool need_weight); std::vector<int> &actual_sizes, bool need_weight);
...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable { ...@@ -370,9 +434,11 @@ class GraphTable : public SparseTable {
int &actual_sizes); int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges( virtual int32_t get_nodes_ids_by_ranges(
std::vector<std::pair<int, int>> ranges, std::vector<uint64_t> &res); std::vector<std::pair<int, int>> ranges, std::vector<int64_t> &res);
virtual int32_t initialize(); virtual int32_t initialize() { return 0; }
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t initialize(const GraphParameter &config);
int32_t load(const std::string &path, const std::string &param); int32_t load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path); int32_t load_graph_split_config(const std::string &path);
...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable { ...@@ -380,13 +446,13 @@ class GraphTable : public SparseTable {
int32_t load_nodes(const std::string &path, std::string node_type); int32_t load_nodes(const std::string &path, std::string node_type);
int32_t add_graph_node(std::vector<uint64_t> &id_list, int32_t add_graph_node(std::vector<int64_t> &id_list,
std::vector<bool> &is_weight_list); std::vector<bool> &is_weight_list);
int32_t remove_graph_node(std::vector<uint64_t> &id_list); int32_t remove_graph_node(std::vector<int64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id); int32_t get_server_index_by_id(int64_t id);
Node *find_node(uint64_t id); Node *find_node(int64_t id);
virtual int32_t pull_sparse(float *values, virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) { const PullSparseValue &pull_value) {
...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable { ...@@ -407,16 +473,27 @@ class GraphTable : public SparseTable {
return 0; return 0;
} }
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index); virtual int32_t set_shard(size_t shard_idx, size_t server_num) {
virtual uint32_t get_thread_pool_index(uint64_t node_id); _shard_idx = shard_idx;
/*
_shard_num is not used in graph_table, this following operation is for the
purpose of
being compatible with base class table.
*/
_shard_num = server_num;
this->server_num = server_num;
return 0;
}
virtual uint32_t get_thread_pool_index_by_shard_index(int64_t shard_index);
virtual uint32_t get_thread_pool_index(int64_t node_id);
virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str); virtual std::pair<int32_t, std::string> parse_feature(std::string feat_str);
virtual int32_t get_node_feat(const std::vector<uint64_t> &node_ids, virtual int32_t get_node_feat(const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res); std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat( virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names, const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res); const std::vector<std::vector<std::string>> &res);
...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable { ...@@ -433,11 +510,25 @@ class GraphTable : public SparseTable {
} }
return 0; return 0;
} }
#ifdef PADDLE_WITH_HETERPS
virtual int32_t start_graph_sampling() {
return this->graph_sampler->start_graph_sampling();
}
virtual int32_t end_graph_sampling() {
return this->graph_sampler->end_graph_sampling();
}
virtual int32_t set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
graph_sampler->set_graph_sample_callback(callback);
return 0;
}
// virtual GraphSampler *get_graph_sampler() { return graph_sampler.get(); }
#endif
protected: protected:
std::vector<GraphShard *> shards, extra_shards; std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num; size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24; int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3; const int random_sample_nodes_ranges = 3;
std::vector<std::string> feat_name; std::vector<std::string> feat_name;
...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable { ...@@ -450,11 +541,61 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool; std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru; std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes; std::unordered_set<int64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index; std::unordered_map<int64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes; bool use_cache, use_duplicate_nodes;
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
bool gpups_mode;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
std::shared_ptr<GraphSampler> graph_sampler;
REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
size_t gpu_num;
int node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<int64_t, std::vector<int64_t>>>
sample_neighbors_map;
}; };
#endif
} // namespace distributed } // namespace distributed
}; // namespace paddle }; // namespace paddle
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h" #include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace paddle {
...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer { ...@@ -117,13 +118,9 @@ class TruncatedGaussianInitializer : public Initializer {
seed_ = static_cast<unsigned int>(std::stoi(attrs[1])); seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]); mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]); std_ = std::stof(attrs[3]);
auto normal_cdf = [](float x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; std::uniform_real_distribution<float> dist_(
}; std::numeric_limits<float>::min(), 1.0);
float a_normal_cdf = normal_cdf((-2.0 - mean_) / std_);
float b_normal_cdf = normal_cdf((2.0 - mean_) / std_);
std::uniform_real_distribution<float> dist_(2.0 * a_normal_cdf - 1.0,
2.0 * b_normal_cdf - 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_); random_engine_ = framework::GetCPURandomEngine(seed_);
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define DECLARE_GRAPH_FRIEND_CLASS(a) friend class a;
#define DECLARE_1_FRIEND_CLASS(a, ...) DECLARE_GRAPH_FRIEND_CLASS(a)
#define DECLARE_2_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_1_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_3_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_2_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_4_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_3_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_5_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_4_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_6_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_5_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_7_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_6_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_8_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_7_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_9_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_8_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_10_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_9_FRIEND_CLASS(__VA_ARGS__)
#define DECLARE_11_FRIEND_CLASS(a, ...) \
DECLARE_GRAPH_FRIEND_CLASS(a) DECLARE_10_FRIEND_CLASS(__VA_ARGS__)
#define REGISTER_GRAPH_FRIEND_CLASS(n, ...) \
DECLARE_##n##_FRIEND_CLASS(__VA_ARGS__)
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void GraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void GraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
} }
void WeightedGraphEdgeBlob::add_edge(uint64_t id, float weight = 1) { void WeightedGraphEdgeBlob::add_edge(int64_t id, float weight = 1) {
id_arr.push_back(id); id_arr.push_back(id);
weight_arr.push_back(weight); weight_arr.push_back(weight);
} }
......
...@@ -24,19 +24,20 @@ class GraphEdgeBlob { ...@@ -24,19 +24,20 @@ class GraphEdgeBlob {
GraphEdgeBlob() {} GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {} virtual ~GraphEdgeBlob() {}
size_t size() { return id_arr.size(); } size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
uint64_t get_id(int idx) { return id_arr[idx]; } int64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; } virtual float get_weight(int idx) { return 1; }
std::vector<int64_t>& export_id_array() { return id_arr; }
protected: protected:
std::vector<uint64_t> id_arr; std::vector<int64_t> id_arr;
}; };
class WeightedGraphEdgeBlob : public GraphEdgeBlob { class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public: public:
WeightedGraphEdgeBlob() {} WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {} virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight); virtual void add_edge(int64_t id, float weight);
virtual float get_weight(int idx) { return weight_arr[idx]; } virtual float get_weight(int idx) { return weight_arr[idx]; }
protected: protected:
......
...@@ -48,6 +48,7 @@ class Node { ...@@ -48,6 +48,7 @@ class Node {
virtual void set_feature(int idx, std::string str) {} virtual void set_feature(int idx, std::string str) {}
virtual void set_feature_size(int size) {} virtual void set_feature_size(int size) {}
virtual int get_feature_size() { return 0; } virtual int get_feature_size() { return 0; }
virtual size_t get_neighbor_size() { return 0; }
protected: protected:
uint64_t id; uint64_t id;
...@@ -70,6 +71,7 @@ class GraphNode : public Node { ...@@ -70,6 +71,7 @@ class GraphNode : public Node {
} }
virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); } virtual uint64_t get_neighbor_id(int idx) { return edges->get_id(idx); }
virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); } virtual float get_neighbor_weight(int idx) { return edges->get_weight(idx); }
virtual size_t get_neighbor_size() { return edges->size(); }
protected: protected:
Sampler *sampler; Sampler *sampler;
......
...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable); ...@@ -37,6 +37,8 @@ REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(Table, SSDSparseTable); REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(GraphSampler, CompleteGraphSampler);
REGISTER_PSCORE_CLASS(GraphSampler, BasicBfsGraphSampler);
#endif #endif
REGISTER_PSCORE_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
......
...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv ...@@ -24,6 +24,9 @@ cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope serv
set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_table_sample_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_table_sample_test SRCS graph_table_sample_test.cc DEPS scope server communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table) cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
...@@ -236,7 +236,7 @@ void RunGraphSplit() { ...@@ -236,7 +236,7 @@ void RunGraphSplit() {
sleep(2); sleep(2);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -250,16 +250,16 @@ void RunGraphSplit() { ...@@ -250,16 +250,16 @@ void RunGraphSplit() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
_vs.clear(); _vs.clear();
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 97), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(3, _vs[0].size()); ASSERT_EQ(3, _vs[0].size());
std::remove(edge_file_name); std::remove(edge_file_name);
......
...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed; ...@@ -48,10 +48,10 @@ namespace distributed = paddle::distributed;
void testSampleNodes( void testSampleNodes(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<uint64_t> ids; std::vector<int64_t> ids;
auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids); auto pull_status = worker_ptr_->random_sample_nodes(0, 0, 6, ids);
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {37, 59}; std::unordered_set<int64_t> s1 = {37, 59};
pull_status.wait(); pull_status.wait();
for (auto id : ids) s.insert(id); for (auto id : ids) s.insert(id);
ASSERT_EQ(true, s.size() == s1.size()); ASSERT_EQ(true, s.size() == s1.size());
...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() { ...@@ -106,14 +106,14 @@ void testFeatureNodeSerializeFloat64() {
void testSingleSampleNeighboor( void testSingleSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
auto pull_status = worker_ptr_->batch_sample_neighbors( auto pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 37), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor( ...@@ -126,7 +126,7 @@ void testSingleSampleNeighboor(
vs.clear(); vs.clear();
vs1.clear(); vs1.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 96), 4, vs, vs1, true); 0, std::vector<int64_t>(1, 96), 4, vs, vs1, true);
pull_status.wait(); pull_status.wait();
s1 = {111, 48, 247}; s1 = {111, 48, 247};
for (auto g : vs[0]) { for (auto g : vs[0]) {
...@@ -147,30 +147,30 @@ void testAddNode( ...@@ -147,30 +147,30 @@ void testAddNode(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
worker_ptr_->clear_nodes(0); worker_ptr_->clear_nodes(0);
int total_num = 270000; int total_num = 270000;
uint64_t id; int64_t id;
std::unordered_set<uint64_t> id_set; std::unordered_set<int64_t> id_set;
for (int i = 0; i < total_num; i++) { for (int i = 0; i < total_num; i++) {
while (id_set.find(id = rand()) != id_set.end()) while (id_set.find(id = rand()) != id_set.end())
; ;
id_set.insert(id); id_set.insert(id);
} }
std::vector<uint64_t> id_list(id_set.begin(), id_set.end()); std::vector<int64_t> id_list(id_set.begin(), id_set.end());
std::vector<bool> weight_list; std::vector<bool> weight_list;
auto status = worker_ptr_->add_graph_node(0, id_list, weight_list); auto status = worker_ptr_->add_graph_node(0, id_list, weight_list);
status.wait(); status.wait();
std::vector<uint64_t> ids[2]; std::vector<int64_t> ids[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto sample_status = auto sample_status =
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check.insert(x); for (auto x : ids[1]) id_set_check.insert(x);
ASSERT_EQ(id_set.size(), id_set_check.size()); ASSERT_EQ(id_set.size(), id_set_check.size());
for (auto x : id_set) { for (auto x : id_set) {
ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true); ASSERT_EQ(id_set_check.find(x) != id_set_check.end(), true);
} }
std::vector<uint64_t> remove_ids; std::vector<int64_t> remove_ids;
for (auto p : id_set_check) { for (auto p : id_set_check) {
if (remove_ids.size() == 0) if (remove_ids.size() == 0)
remove_ids.push_back(p); remove_ids.push_back(p);
...@@ -187,7 +187,7 @@ void testAddNode( ...@@ -187,7 +187,7 @@ void testAddNode(
worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]); worker_ptr_->random_sample_nodes(0, i, total_num, ids[i]);
sample_status.wait(); sample_status.wait();
} }
std::unordered_set<uint64_t> id_set_check1(ids[0].begin(), ids[0].end()); std::unordered_set<int64_t> id_set_check1(ids[0].begin(), ids[0].end());
for (auto x : ids[1]) id_set_check1.insert(x); for (auto x : ids[1]) id_set_check1.insert(x);
ASSERT_EQ(id_set_check1.size(), id_set_check.size()); ASSERT_EQ(id_set_check1.size(), id_set_check.size());
for (auto x : id_set_check1) { for (auto x : id_set_check1) {
...@@ -196,14 +196,14 @@ void testAddNode( ...@@ -196,14 +196,14 @@ void testAddNode(
} }
void testBatchSampleNeighboor( void testBatchSampleNeighboor(
std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) { std::shared_ptr<paddle::distributed::GraphBrpcClient>& worker_ptr_) {
std::vector<std::vector<uint64_t>> vs; std::vector<std::vector<int64_t>> vs;
std::vector<std::vector<float>> vs1; std::vector<std::vector<float>> vs1;
std::vector<std::uint64_t> v = {37, 96}; std::vector<std::int64_t> v = {37, 96};
auto pull_status = auto pull_status =
worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false); worker_ptr_->batch_sample_neighbors(0, v, 4, vs, vs1, false);
pull_status.wait(); pull_status.wait();
std::unordered_set<uint64_t> s; std::unordered_set<int64_t> s;
std::unordered_set<uint64_t> s1 = {112, 45, 145}; std::unordered_set<int64_t> s1 = {112, 45, 145};
for (auto g : vs[0]) { for (auto g : vs[0]) {
s.insert(g); s.insert(g);
} }
...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() { ...@@ -417,7 +417,7 @@ void RunBrpcPushSparse() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions; std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert( dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {})); std::pair<int64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0]; auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service()); RunClient(dense_regions, 0, pserver_ptr_->get_service());
...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() { ...@@ -427,14 +427,14 @@ void RunBrpcPushSparse() {
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>")); worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0)); srand(time(0));
pull_status.wait(); pull_status.wait();
std::vector<std::vector<uint64_t>> _vs; std::vector<std::vector<int64_t>> _vs;
std::vector<std::vector<float>> vs; std::vector<std::vector<float>> vs;
testSampleNodes(worker_ptr_); testSampleNodes(worker_ptr_);
sleep(5); sleep(5);
testSingleSampleNeighboor(worker_ptr_); testSingleSampleNeighboor(worker_ptr_);
testBatchSampleNeighboor(worker_ptr_); testBatchSampleNeighboor(worker_ptr_);
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true); 0, std::vector<int64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(0, _vs[0].size()); ASSERT_EQ(0, _vs[0].size());
paddle::distributed::GraphTable* g = paddle::distributed::GraphTable* g =
...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() { ...@@ -445,14 +445,14 @@ void RunBrpcPushSparse() {
while (round--) { while (round--) {
vs.clear(); vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, _vs, vs, false); 0, std::vector<int64_t>(1, 37), 1, _vs, vs, false);
pull_status.wait(); pull_status.wait();
for (int i = 0; i < ttl; i++) { for (int i = 0; i < ttl; i++) {
std::vector<std::vector<uint64_t>> vs1; std::vector<std::vector<int64_t>> vs1;
std::vector<std::vector<float>> vs2; std::vector<std::vector<float>> vs2;
pull_status = worker_ptr_->batch_sample_neighbors( pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 37), 1, vs1, vs2, false); 0, std::vector<int64_t>(1, 37), 1, vs1, vs2, false);
pull_status.wait(); pull_status.wait();
ASSERT_EQ(_vs[0].size(), vs1[0].size()); ASSERT_EQ(_vs[0].size(), vs1[0].size());
...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() { ...@@ -540,7 +540,7 @@ void RunBrpcPushSparse() {
// Test Pull by step // Test Pull by step
std::unordered_set<uint64_t> count_item_nodes; std::unordered_set<int64_t> count_item_nodes;
// pull by step 2 // pull by step 2
for (int test_step = 1; test_step < 4; test_step++) { for (int test_step = 1; test_step < 4; test_step++) {
count_item_nodes.clear(); count_item_nodes.clear();
...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() { ...@@ -558,18 +558,18 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12); ASSERT_EQ(count_item_nodes.size(), 12);
} }
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res; std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors( res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false); std::string("user2item"), std::vector<int64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3); ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids; std::vector<int64_t> node_ids;
node_ids.push_back(96); node_ids.push_back(96);
node_ids.push_back(37); node_ids.push_back(37);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4, res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false); true, false);
ASSERT_EQ(res.first[1].size(), 1); ASSERT_EQ(res.first[1].size(), 1);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6); std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
(nodes_ids[0] == 37 && nodes_ids[1] == 59)); (nodes_ids[0] == 37 && nodes_ids[1] == 59));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
std::string("59\t45\t0.34"), std::string("59\t145\t0.31"),
std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
// odd id:96 48 122 112
char edge_file_name[] = "edges.txt";
std::vector<std::string> nodes = {
std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"),
std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"),
std::string("user\t59\ta 0.11\tb 11 14"),
std::string("user\t97\ta 0.11\tb 12 11"),
std::string("item\t45\ta 0.21"),
std::string("item\t145\ta 0.21"),
std::string("item\t112\ta 0.21"),
std::string("item\t48\ta 0.21"),
std::string("item\t247\ta 0.21"),
std::string("item\t111\ta 0.21"),
std::string("item\t46\ta 0.21"),
std::string("item\t146\ta 0.21"),
std::string("item\t122\ta 0.21"),
std::string("item\t49\ta 0.21"),
std::string("item\t248\ta 0.21"),
std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt";
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
void testGraphSample() {
#ifdef PADDLE_WITH_HETERPS
::paddle::distributed::GraphParameter table_proto;
table_proto.set_gpups_mode(true);
table_proto.set_gpups_mode_shard_num(127);
table_proto.set_gpu_num(2);
distributed::GraphTable graph_table, graph_table1;
graph_table.initialize(table_proto);
prepare_file(edge_file_name, edges);
graph_table.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res;
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_table.set_graph_sample_callback(
[&res, &prom](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res = res0;
prom.set_value(0);
});
graph_table.start_graph_sampling();
fut.get();
graph_table.end_graph_sampling();
ASSERT_EQ(2, res.size());
// 37 59 97
for (int i = 0; i < (int)res[1].node_size; i++) {
std::cout << res[1].node_list[i].node_id << std::endl;
}
ASSERT_EQ(3, res[1].node_size);
::paddle::distributed::GraphParameter table_proto1;
table_proto1.set_gpups_mode(true);
table_proto1.set_gpups_mode_shard_num(127);
table_proto1.set_gpu_num(2);
table_proto1.set_gpups_graph_sample_class("BasicBfsGraphSampler");
table_proto1.set_gpups_graph_sample_args("5,5,1,1");
graph_table1.initialize(table_proto1);
graph_table1.load(std::string(edge_file_name), std::string("e>"));
std::vector<paddle::framework::GpuPsCommGraph> res1;
std::promise<int> prom1;
std::future<int> fut1 = prom1.get_future();
graph_table1.set_graph_sample_callback(
[&res1, &prom1](std::vector<paddle::framework::GpuPsCommGraph> &res0) {
res1 = res0;
prom1.set_value(0);
});
graph_table1.start_graph_sampling();
fut1.get();
graph_table1.end_graph_sampling();
// distributed::BasicBfsGraphSampler *sampler1 =
// (distributed::BasicBfsGraphSampler *)graph_table1.get_graph_sampler();
// sampler1->start_graph_sampling();
// std::this_thread::sleep_for (std::chrono::seconds(1));
// std::vector<paddle::framework::GpuPsCommGraph> res1;// =
// sampler1->fetch_sample_res();
ASSERT_EQ(2, res1.size());
// odd id:96 48 122 112
for (int i = 0; i < (int)res1[0].node_size; i++) {
std::cout << res1[0].node_list[i].node_id << std::endl;
}
ASSERT_EQ(4, res1[0].node_size);
#endif
}
TEST(testGraphSample, Run) { testGraphSample(); }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "glog/logging.h" #include "glog/logging.h"
DECLARE_bool(retain_grad_for_all_tensor);
namespace egr { namespace egr {
static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
...@@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, ...@@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
} }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation:: std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation"; VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1, PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
...@@ -62,7 +62,7 @@ operator()( ...@@ -62,7 +62,7 @@ operator()(
grad_out = grads[0][0]; grad_out = grads[0][0];
} }
if (!weak_grad_.expired()) { if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) {
auto grad = weak_grad_.lock(); auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out); CopyOrAddTensor(grad.get(), grad_out);
} }
......
...@@ -35,8 +35,15 @@ class GradNodeAccumulation : public GradNodeBase { ...@@ -35,8 +35,15 @@ class GradNodeAccumulation : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
std::string name() { return "GradNodeAccumulation"; } std::string name() { return "GradNodeAccumulation"; }
......
...@@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X( ...@@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X(
void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; } void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale:: std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
// 1. Check Output Size // 1. Check Output Size
PADDLE_ENFORCE( PADDLE_ENFORCE(
((grads.size() == 1) && (grads[0].size() == 1)), ((grads.size() == 1) && (grads[0].size() == 1)),
......
...@@ -39,8 +39,15 @@ class GradNodeScale : public GradNodeBase { ...@@ -39,8 +39,15 @@ class GradNodeScale : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetTensorWrappers_X( void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors); const std::vector<paddle::experimental::Tensor>& tensors);
......
...@@ -86,9 +86,9 @@ paddle::experimental::Tensor scale(const paddle::experimental::Tensor& x, ...@@ -86,9 +86,9 @@ paddle::experimental::Tensor scale(const paddle::experimental::Tensor& x,
scale_node->SetTensorWrappers_X({x}); scale_node->SetTensorWrappers_X({x});
// Set Grad out rank as same as fwd input and set stop gradient to bwd // Set Grad out rank as same as fwd input and set stop gradient to bwd
scale_node->SetGradOutMeta(p_autograd_in, /*slot id*/ 0); scale_node->SetGradOutMeta(x, /*slot id*/ 0);
// Set Grad out rank as same as fwd input and set stop gradient to bwd // Set Grad out rank as same as fwd input and set stop gradient to bwd
scale_node->SetGradInMeta(p_autograd_out, /*slot id*/ 0); scale_node->SetGradInMeta(out, /*slot id*/ 0);
// Set History for output set current Grad Node for // Set History for output set current Grad Node for
EagerUtils::SetHistory(p_autograd_out, scale_node); EagerUtils::SetHistory(p_autograd_out, scale_node);
......
...@@ -30,7 +30,8 @@ namespace egr_utils_api { ...@@ -30,7 +30,8 @@ namespace egr_utils_api {
bool IsLeafTensor(const paddle::experimental::Tensor& target) { bool IsLeafTensor(const paddle::experimental::Tensor& target) {
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(target); std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(target);
if (std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) { if (!grad_node ||
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) {
return true; return true;
} }
......
...@@ -27,6 +27,7 @@ add_custom_target(eager_final_state_codegen ...@@ -27,6 +27,7 @@ add_custom_target(eager_final_state_codegen
set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h") set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h")
set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h") set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h")
add_custom_target(eager_final_state_python_c_codegen add_custom_target(eager_final_state_python_c_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py" COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path}" "--api_yaml_path=${api_yaml_path}"
......
...@@ -28,6 +28,7 @@ namespace = "" ...@@ -28,6 +28,7 @@ namespace = ""
yaml_types_mapping = { yaml_types_mapping = {
'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \ 'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'str' : 'std::string', \
'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', 'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor', 'Tensor' : 'Tensor',
...@@ -212,7 +213,8 @@ def ParseYamlArgs(string): ...@@ -212,7 +213,8 @@ def ParseYamlArgs(string):
default_value = m.group(3).split("=")[1].strip() if len( default_value = m.group(3).split("=")[1].strip() if len(
m.group(3).split("=")) > 1 else None m.group(3).split("=")) > 1 else None
assert arg_type in yaml_types_mapping.keys() assert arg_type in yaml_types_mapping.keys(
), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
arg_type = yaml_types_mapping[arg_type] arg_type = yaml_types_mapping[arg_type]
arg_name = RemoveSpecialSymbolsInName(arg_name) arg_name = RemoveSpecialSymbolsInName(arg_name)
...@@ -247,7 +249,8 @@ def ParseYamlReturns(string): ...@@ -247,7 +249,8 @@ def ParseYamlReturns(string):
else: else:
ret_type = ret.strip() ret_type = ret.strip()
assert ret_type in yaml_types_mapping.keys() assert ret_type in yaml_types_mapping.keys(
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type] ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type assert "Tensor" in ret_type
...@@ -475,6 +478,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -475,6 +478,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
# SetTensorWrapper Methods & TensorWrapper Members # SetTensorWrapper Methods & TensorWrapper Members
set_tensor_wrapper_methods_str = "" set_tensor_wrapper_methods_str = ""
tensor_wrapper_members_str = "" tensor_wrapper_members_str = ""
clear_tensor_wrapper_str = ""
for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items():
if tname in no_need_buffer_set: if tname in no_need_buffer_set:
no_need_buffer = "true" no_need_buffer = "true"
...@@ -496,6 +500,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -496,6 +500,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
""" """
tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name) tensor_wrapper_name)
CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
{}.clear();
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
else: else:
assert IsVectorTensorType(ttype) assert IsVectorTensorType(ttype)
SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """
...@@ -513,6 +524,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -513,6 +524,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
""" """
tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
tensor_wrapper_name) tensor_wrapper_name)
CLEAR_TENSOR_WRAPPERS_TEMPLATE = """
for (auto tw: {}) {
tw.clear();
};
"""
clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format(
tensor_wrapper_name)
# End: SetTensorWrapper Methods & TensorWrapper Members # End: SetTensorWrapper Methods & TensorWrapper Members
# SetAttributes & Attribute Members # SetAttributes & Attribute Members
...@@ -521,7 +541,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -521,7 +541,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
for aname, atype, default_val, _ in backward_attrs_list: for aname, atype, default_val, _ in backward_attrs_list:
saved_attr_name = GetSavedName(aname) saved_attr_name = GetSavedName(aname)
SET_ATTR_METHOD_TEMPLATE = """ SET_ATTR_METHOD_TEMPLATE = """
void SetAttribute{}({} {}) {{ void SetAttribute{}({} {}) {{
{} = {}; {} = {};
}} }}
""" """
...@@ -552,25 +572,37 @@ class {} : public egr::GradNodeBase {{ ...@@ -552,25 +572,37 @@ class {} : public egr::GradNodeBase {{
~{}() override = default; ~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) override; const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }} std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ... // SetTensorWrapperX, SetTensorWrapperY, ...
{} {}
// SetAttributes // SetAttributes
{} {}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private: private:
// TensorWrappers // TensorWrappers
{} {}
bool is_tensor_wrappers_cleared = false;
// Attributes // Attributes
{} {}
}}; }};
""" """
node_declaration_str = NODE_DECLARATION_TEMPLATE.format( node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, set_tensor_wrapper_methods_str, grad_node_name, clear_tensor_wrapper_str,
set_attribute_methods_str, tensor_wrapper_members_str, set_tensor_wrapper_methods_str, set_attribute_methods_str,
attribute_members_str) tensor_wrapper_members_str, attribute_members_str)
return node_declaration_str return node_declaration_str
...@@ -624,6 +656,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -624,6 +656,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
else: else:
# Rearrange output order accordingly # Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n" returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n" returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name) grad_node_name = GetGradNodeName(fwd_api_name)
...@@ -634,7 +667,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -634,7 +667,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_namespace = f"paddle::experimental" grad_api_namespace = f"paddle::experimental"
FUNCTION_TEMPLATE = """ FUNCTION_TEMPLATE = """
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Call grad_api function // Call grad_api function
auto grad_api_returns = {}::{}({}); auto grad_api_returns = {}::{}({});
{} {}
...@@ -697,7 +730,7 @@ def GenerateNodeCreationCodes( ...@@ -697,7 +730,7 @@ def GenerateNodeCreationCodes(
else: else:
# Tuple api_result # Tuple api_result
if IsPlainTensorType(rtype): if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);" output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else: else:
assert IsVectorTensorType(rtype) assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n" output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
...@@ -734,8 +767,11 @@ def GenerateNodeCreationCodes( ...@@ -734,8 +767,11 @@ def GenerateNodeCreationCodes(
else: else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else: else:
if IsVectorTensorType(atype): if num_fwd_outputs > 1:
tw_name = f"api_result[{pos}]" # Aligned with forward output position
assert name in forward_outputs_position_map.keys()
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else: else:
tw_name = f"api_result" tw_name = f"api_result"
...@@ -751,7 +787,7 @@ def GenerateNodeCreationCodes( ...@@ -751,7 +787,7 @@ def GenerateNodeCreationCodes(
set_edges_list = [] set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items(): for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name) input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({input_autograd_meta_name}, {pos});" set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});" set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta) set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges) set_edges_list.append(set_edges)
...@@ -768,17 +804,18 @@ def GenerateNodeCreationCodes( ...@@ -768,17 +804,18 @@ def GenerateNodeCreationCodes(
output_autograd_meta_name = GetAutoGradMetaName(name) output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_grad_in_meta = f" grad_node->SetGradInMeta({output_autograd_meta_name}, {pos});" if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank) set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history) set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta) set_grad_in_meta_list.append(set_grad_in_meta)
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result[{pos}]);"
set_retain_grad_list.append(set_retain_grad) set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list) set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list) set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list) set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
...@@ -900,7 +937,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -900,7 +937,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
returns_list[0] = f"api_result" returns_list[0] = f"api_result"
else: else:
# Tuple api_result # Tuple api_result
returns_list[pos] = f"api_result[{pos}]" returns_list[pos] = f"std::get<{pos}>(api_result)"
if IsPlainTensorType(rtype): if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor" returns_type_list[pos] = "paddle::experimental::Tensor"
...@@ -1050,7 +1087,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): ...@@ -1050,7 +1087,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h" #include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/backward/sparse_bw_api.h"
""" """
file_contents += node_definition_str file_contents += node_definition_str
with open(filepath, 'a') as f: with open(filepath, 'a') as f:
...@@ -1245,7 +1282,7 @@ if __name__ == "__main__": ...@@ -1245,7 +1282,7 @@ if __name__ == "__main__":
# Node Definition Generation # Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition( definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, orig_forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs, backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs) intermediate_outputs)
...@@ -1257,7 +1294,7 @@ if __name__ == "__main__": ...@@ -1257,7 +1294,7 @@ if __name__ == "__main__":
# For python-level API dispatch # For python-level API dispatch
CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map, CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_outputs_position_map,
forward_attrs_list) orig_forward_attrs_list)
if len(namespace) > 0: if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{ forward_definition_str += f"""namespace {namespace} {{
......
...@@ -39,12 +39,21 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -39,12 +39,21 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
// Copy nodes // Copy nodes
std::queue<GradNodeBase*> queue = init_queue; std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited; std::unordered_set<GradNodeBase*> visited;
size_t potential_startup_ops_cnt = queue.size();
size_t cnt = 0;
// Visit each node exactly once in any order // Visit each node exactly once in any order
while (!queue.empty()) { while (!queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = queue.front();
queue.pop(); queue.pop();
if (cnt < potential_startup_ops_cnt) {
if (!node_in_degree_map.count(node)) {
node_in_degree_map[node] = 0;
}
cnt += 1;
}
if (visited.count(node)) { if (visited.count(node)) {
continue; continue;
} }
...@@ -76,23 +85,248 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -76,23 +85,248 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
return node_in_degree_map; return node_in_degree_map;
} }
void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, // Remove some nodes those doesn't need to be
const std::vector<paddle::experimental::Tensor>& grad_tensors, // stored in potential_stop_nodes、potential_startup_nodes
bool retain_graph) { void UpdateGraphInfo(
paddle::platform::RecordEvent backward_record_event( std::unordered_map<GradNodeBase*, AutogradMeta*>*
"backward", paddle::platform::TracerEventType::Operator, 1); target_nodes_inputmeta_map,
std::unordered_map<GradNodeBase*, std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
// Updated potential_sotp_nodes by depending_nodes,
// make sure the path from root to target_node is ok
std::unordered_set<GradNodeBase*> _startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : *target_nodes_inputmeta_map) {
queue.emplace(target_nodes_inputmeta_pair.first);
}
while (!queue.empty()) {
auto* target_node = queue.front();
queue.pop();
if (!(*depending_nodes)[target_node].empty()) {
auto precedding_nodes = (*depending_nodes)[target_node];
for (auto pre_nodes : precedding_nodes) {
queue.emplace(pre_nodes);
if (potential_stop_nodes->find(pre_nodes) !=
potential_stop_nodes->end()) {
potential_stop_nodes->erase(pre_nodes);
}
}
} else { // startup_ops have no precedding nodes
VLOG(6) << "Emplace _startup_ops";
_startup_ops.emplace(target_node);
}
}
// Purify potential_startup_nodes again, remove some
// potential startup_nodes that unreach to input target nodes
if (!_startup_ops.empty()) {
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto node : *potential_startup_nodes) {
if (_startup_ops.count(node) == 0) {
VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
potential_startup_nodes_to_be_erased.emplace(node);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto node : potential_startup_nodes_to_be_erased) {
VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
potential_startup_nodes->erase(node);
}
}
}
}
// Get Graph Info Betweent input target gradnode and outputs,
// record depending_nodes、 potential_stop_nodes、potential_startup_nodes
void GetGraphInfoBetweenTargets(
const std::queue<GradNodeBase*>& init_queue,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
input_target_nodes_inputmeta_map,
std::unordered_map</*child node*/ GradNodeBase*,
/*father nodes*/ std::unordered_set<GradNodeBase*>>*
depending_nodes,
std::unordered_set<GradNodeBase*>* potential_stop_nodes,
std::unordered_set<GradNodeBase*>* potential_startup_nodes) {
if (input_target_nodes_inputmeta_map->empty()) return;
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
// Calculate in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map;
// Copy nodes
std::queue<GradNodeBase*> queue = init_queue;
std::unordered_set<GradNodeBase*> visited;
// Visit each node exactly once in any order
while (!queue.empty()) {
GradNodeBase* node = queue.front();
queue.pop();
if (visited.count(node)) {
continue;
}
visited.insert(node);
// Check node is target_nodes or not, if node is not target_node,
// all the next_node will be marked in potential_stop_nodes
bool is_potential_stop_nodes =
input_target_nodes_inputmeta_map->count(node);
// Find and append next nodes
const std::vector<std::vector<Edge>>& edges = node->GetEdges();
for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
if (!next_node) continue;
// if node not in input_target_nodes,
// all the next_nodes of current node will be inserted to
// potential_stop_node
if (is_potential_stop_nodes) {
potential_stop_nodes->emplace(next_node);
}
// Update in_degree
if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0;
node_in_degree_map[next_node]++;
// Record depending relationship
(*depending_nodes)[next_node].emplace(node);
queue.push(next_node);
}
}
}
// Update Graph Info, remove some stop_node in potential_stop_nodes
UpdateGraphInfo(input_target_nodes_inputmeta_map, depending_nodes,
potential_stop_nodes, potential_startup_nodes);
}
void GetTargetNodesInfo(const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, AutogradMeta*>*
target_nodes_inputmeta_map) {
VLOG(6) << "Running in GetTargetNodesInfo";
if (!inputs.empty()) {
VLOG(6) << "Inputs are not empty";
size_t num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for input:%d or it's"
"stop_gradient=True",
i));
(*target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
}
}
}
std::vector<paddle::experimental::Tensor> GetResults(
const std::vector<paddle::experimental::Tensor>& inputs,
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor>*
results_map,
bool allow_unused, bool create_graph) {
VLOG(6) << "Running in GetResults";
if (inputs.empty()) return {};
std::vector<paddle::experimental::Tensor> results;
results.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto iter = results_map->find(target_node);
if (iter != results_map->end()) {
// set StopGradient = !create_graph
AutogradMeta* tensor_auto_grad_meta =
EagerUtils::autograd_meta(&(iter->second));
tensor_auto_grad_meta->SetStopGradient(!create_graph);
results.emplace_back(iter->second);
} else {
PADDLE_ENFORCE_EQ(allow_unused, true,
paddle::platform::errors::InvalidArgument(
"The %d-th input does not appear in the backward "
"graph. Please check the input variable or set "
"allow_unused=True to get None result.",
i));
results.emplace_back();
}
}
return results;
}
// Enforce GradNode has TensorWrappers as Input
void EnforceGradNodeHasInput(GradNodeBase* node) {
VLOG(6) << "Running in EnforceGradNodeHasInput";
PADDLE_ENFORCE_NE(
node->IsTensorWrappersCleared(), true,
paddle::platform::errors::Fatal(
"The TensorWrappers of %s do not exist. This may be because:\n"
"You calculate backward twice for the same subgraph without "
"setting retain_graph=True. Please set retain_graph=True in the "
"first backward/grad call.\n",
node->name()));
}
// Purify potential_startup_nodes, remove nodes those are the same as
// input_target_nodes
void PurifyPotentialStartUpNodes(
std::unordered_set<GradNodeBase*>* potential_startup_nodes,
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>*
input_target_nodes_inputmeta_map) {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map->empty()) return;
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : *potential_startup_nodes) {
auto iter = input_target_nodes_inputmeta_map->find(startup_op);
if (iter != input_target_nodes_inputmeta_map->end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto nodes : potential_startup_nodes_to_be_erased) {
potential_startup_nodes->erase(nodes);
}
}
}
std::vector<paddle::experimental::Tensor> RunBackward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph = false,
const std::vector<paddle::experimental::Tensor>& inputs = {},
bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
VLOG(6) << "Start Backward"; VLOG(6) << "Start Backward";
// *Gradient Hook should happen at node-level // *Gradient Hook should happen at node-level
// *Inplace version check should perform at node-level // *Inplace version check should perform at node-level
// *Cross-batch accumulation happens at forward pass // *Cross-batch accumulation happens at forward pass
std::unordered_map<GradNodeBase*, AutogradMeta*>
no_grad_var_nodes_inputmeta_map;
// Get no_grad_vars's GradNodes and InputMeta Info
GetTargetNodesInfo(no_grad_vars, &no_grad_var_nodes_inputmeta_map);
/* --- Initialization --- */ /* --- Initialization --- */
// 1. Init queue with starting nodes // 1. Init queue with starting nodes
// 2. Prepare initial input buffers // 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue; std::queue<GradNodeBase*> queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>> std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict; node_input_buffers_dict;
std::unordered_set<GradNodeBase*> potential_startup_nodes;
for (size_t i = 0; i < tensors.size(); i++) { for (size_t i = 0; i < tensors.size(); i++) {
const paddle::experimental::Tensor& tensor = tensors[i]; const paddle::experimental::Tensor& tensor = tensors[i];
...@@ -132,8 +366,17 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -132,8 +366,17 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
"size = 0 or same size as tensors")); "size = 0 or same size as tensors"));
// Feed given tensor if it's provided // Feed given tensor if it's provided
VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor"; VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";
node_input_buffers_dict[grad_node]->add(
input_info.first, input_info.second, grad_tensors[i]); if (grad_tensors[i].is_initialized()) {
// Deep copy
paddle::experimental::Tensor tmp_tensor;
tmp_tensor.copy_(grad_tensors[i], grad_tensors[i].inner_place(), true);
node_input_buffers_dict[grad_node]->add(input_info.first,
input_info.second, tmp_tensor);
} else {
node_input_buffers_dict[grad_node]->add(
input_info.first, input_info.second, grad_tensors[i]);
}
} else { } else {
VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
...@@ -146,8 +389,9 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -146,8 +389,9 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
input_info.first, input_info.second, tensor, true /*fill_one=true*/); input_info.first, input_info.second, tensor, true /*fill_one=true*/);
} }
// Prepare queue // Prepare queue, potential startup_nodes
queue.push(grad_node); queue.push(grad_node);
potential_startup_nodes.emplace(grad_node);
} }
VLOG(6) << "Update In degree Map for backward"; VLOG(6) << "Update In degree Map for backward";
...@@ -155,25 +399,74 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -155,25 +399,74 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
std::unordered_map<GradNodeBase*, int> node_in_degree_map = std::unordered_map<GradNodeBase*, int> node_in_degree_map =
getInDegreeMap(queue); getInDegreeMap(queue);
// Get input's GradNodes and InputMeta Info
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
input_target_nodes_inputmeta_map;
GetTargetNodesInfo(inputs, &input_target_nodes_inputmeta_map);
// Purify potential_startup_ops, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes(&potential_startup_nodes,
&input_target_nodes_inputmeta_map);
// Get Graph Info Betweent input target gradnode and outputs
// Record the depending_nodes and potential_stop_nodes
std::unordered_map<GradNodeBase* /* child node */,
std::unordered_set<GradNodeBase*> /* father node */>
depending_nodes;
std::unordered_set<GradNodeBase*> potential_stop_nodes;
// std::unordered_set<GradNodeBase*> startup_ops;
GetGraphInfoBetweenTargets(queue, &input_target_nodes_inputmeta_map,
&depending_nodes, &potential_stop_nodes,
&potential_startup_nodes);
// ready_queue store all startup nodes
std::queue<GradNodeBase*> ready_queue;
// startup op's indegree should be 0
for (auto node : potential_startup_nodes) {
if (node_in_degree_map[node] == 0) {
ready_queue.emplace(node);
}
}
VLOG(1) << " startup_ops' size is :" << ready_queue.size();
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
// read_queue is empty only when 1.input equals to output. 2.input can not
// reach to output.
if (ready_queue.size() == 0) {
for (auto input_target_node : input_target_nodes_inputmeta_map) {
// out rank_info of forward op
auto rank_info = input_target_node.second->OutRankInfo();
if (node_input_buffers_dict[input_target_node.first]) {
auto& target_result =
node_input_buffers_dict[input_target_node.first]
->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[input_target_node.first] = target_result;
}
}
}
/* --- Topological Visit --- */ /* --- Topological Visit --- */
// 1. Pop queue // 1. Pop queue
// 2. Run node // 2. Run node
// |- Check and capture target result
// |- node(grads) // |- node(grads)
// |- Prepare for next node // |- Prepare for next node
// 3. Update queue // 3. Update queue
VLOG(6) << "Run Backward"; VLOG(6) << "Run Backward";
while (!queue.empty()) { while (!ready_queue.empty()) {
GradNodeBase* node = queue.front(); GradNodeBase* node = ready_queue.front();
VLOG(6) << "Running GradNode:" << node->name();
ready_queue.pop();
paddle::platform::RecordEvent node_record_event( paddle::platform::RecordEvent node_record_event(
std::string(typeid(*node).name()) + " grad_node", std::string(typeid(*node).name()) + " grad_node",
paddle::platform::TracerEventType::Operator, 1); paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
queue.pop();
continue;
}
queue.pop();
// Run node: This is where Hook happens // Run node: This is where Hook happens
PADDLE_ENFORCE( PADDLE_ENFORCE(
node_input_buffers_dict.count(node), node_input_buffers_dict.count(node),
...@@ -184,16 +477,51 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -184,16 +477,51 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
std::unique_ptr<GradTensorHolder> node_input_buffer = std::unique_ptr<GradTensorHolder> node_input_buffer =
std::move(node_input_buffers_dict[node]); std::move(node_input_buffers_dict[node]);
// get target grad_var from node_input_buffer by inputmeta
if (input_target_nodes_inputmeta_map.find(node) !=
input_target_nodes_inputmeta_map.end()) {
VLOG(6) << "Get target result by by inputmeta";
// out rank_info of forward op
auto rank_info = input_target_nodes_inputmeta_map[node]->OutRankInfo();
// rank_info is a pair, first means slot_id, second means rank.
auto& target_result =
node_input_buffer->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[node] = target_result;
}
// no_grad_vars
if (no_grad_var_nodes_inputmeta_map.find(node) !=
no_grad_var_nodes_inputmeta_map.end()) {
VLOG(6) << "Change the input buffer[slot][rank] by Zeros";
auto rank_info = no_grad_var_nodes_inputmeta_map[node]->OutRankInfo();
node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
rank_info.second);
}
VLOG(6) << "Running GradNode:" << node->name();
// check input
EnforceGradNodeHasInput(node);
VLOG(6) << "Run Backward Kernel with GradTensorHolder"; VLOG(6) << "Run Backward Kernel with GradTensorHolder";
// Run Pre Backward Node and get outputs // Run Pre Backward Node and get outputs
std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors = std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
(*node)(node_input_buffer->Buffers()); (*node)(node_input_buffer->Buffers(), create_graph);
// retain_grad or not
if (!retain_graph) {
VLOG(6)
<< "retain_graph is false, need to clear the TensorWrapper of nodes.";
node->ClearTensorWrappers();
}
// TODO(jiabin): Should we erase it or find a more efficient way. // TODO(jiabin): Should we erase it or find a more efficient way.
node_input_buffers_dict.erase(node); node_input_buffers_dict.erase(node);
// Prepare GradTensorHolder for next node // Prepare GradTensorHolder for next node
const std::vector<std::vector<Edge>>& edges = node->GetEdges(); const std::vector<std::vector<Edge>>& edges = node->GetEdges();
PADDLE_ENFORCE(edges.size() == grad_output_tensors.size() || edges.empty(), PADDLE_ENFORCE(edges.size() == grad_output_tensors.size() || edges.empty(),
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
"Number of edges should be either empty ( for leaf node " "Number of edges should be either empty ( for leaf node "
...@@ -204,6 +532,7 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -204,6 +532,7 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
for (size_t i = 0; i < edges.size(); i++) { for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) { for (size_t j = 0; j < edges[i].size(); j++) {
const Edge& edge = edges[i][j]; const Edge& edge = edges[i][j];
auto edge_rank = edge.GetEdgeRankInfo(); auto edge_rank = edge.GetEdgeRankInfo();
// Since we make edge has as same rank as bwd outputs, we indexing them // Since we make edge has as same rank as bwd outputs, we indexing them
// with // with
...@@ -217,6 +546,7 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -217,6 +546,7 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
grad_output_tensors[i].empty()) { grad_output_tensors[i].empty()) {
continue; continue;
} }
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
j, grad_output_tensors[i].size(), j, grad_output_tensors[i].size(),
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
...@@ -252,18 +582,44 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors, ...@@ -252,18 +582,44 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
// Update queue // Update queue
node_in_degree_map[next_node]--; node_in_degree_map[next_node]--;
PADDLE_ENFORCE( PADDLE_ENFORCE(
node_in_degree_map[next_node] >= 0, node_in_degree_map[next_node] >= 0,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
"Detected in-degree value smaller than zero. For Node: %s" "Detected in-degree value smaller than zero. For Node: %s"
"Node's in-degree cannot be negative", "Node's in-degree cannot be negative",
next_node->name())); next_node->name()));
if (node_in_degree_map[next_node] == 0) {
queue.emplace(std::move(next_node)); bool is_potential_stop_node = potential_stop_nodes.count(next_node);
if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) {
ready_queue.emplace(std::move(next_node));
} }
} }
} }
} }
return GetResults(inputs, &results_map, allow_unused, create_graph);
} }
void Backward(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph) {
VLOG(6) << "Run in Backward";
paddle::platform::RecordEvent backward_record_event(
"backward", paddle::platform::TracerEventType::Operator, 1);
RunBackward(tensors, grad_tensors, retain_graph);
}
std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& tensors, // output
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
VLOG(6) << "Run in Grad";
return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs,
allow_unused, no_grad_vars);
}
} // namespace egr } // namespace egr
...@@ -19,12 +19,20 @@ ...@@ -19,12 +19,20 @@
namespace egr { namespace egr {
// run_backward(): // Backward():
// tensors corresponds to those lived in the backward graph // tensors corresponds to those lived in the backward graph
// each grad_tensors[i] keeps the value for its corresponding tensors[i] // each grad_tensors[i] keeps the value for its corresponding tensors[i]
void RunBackward(const std::vector<paddle::experimental::Tensor> &tensors, void Backward(const std::vector<paddle::experimental::Tensor>& tensors,
const std::vector<paddle::experimental::Tensor> &grad_tensors, const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph = false); bool retain_graph = false);
std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& tensors,
const std::vector<paddle::experimental::Tensor>& inputs,
const std::vector<paddle::experimental::Tensor>& grad_tensors = {},
bool retain_graph = false, bool create_graph = false,
bool only_inputs = false, bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {});
// Reserved for gradient() // Reserved for gradient()
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
namespace egr { namespace egr {
std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode:: std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
operator()( operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) { bool create_graph) {
paddle::CustomOpKernelContext ctx; paddle::CustomOpKernelContext ctx;
auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs( auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
......
...@@ -37,8 +37,8 @@ class RunCustomOpNode : public GradNodeBase { ...@@ -37,8 +37,8 @@ class RunCustomOpNode : public GradNodeBase {
// Functor: perform backward computations // Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
override; bool create_graph) override;
std::string name() { std::string name() {
return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_); return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_);
...@@ -62,6 +62,12 @@ class RunCustomOpNode : public GradNodeBase { ...@@ -62,6 +62,12 @@ class RunCustomOpNode : public GradNodeBase {
return res; return res;
} }
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; } void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }
public: public:
......
...@@ -57,21 +57,28 @@ class AutogradMeta; ...@@ -57,21 +57,28 @@ class AutogradMeta;
class GradSlotMeta { class GradSlotMeta {
public: public:
GradSlotMeta() = default; GradSlotMeta() = default;
void Init(size_t size) { bool IsStopGradient() const { return stop_gradient_; }
size_ = static_cast<int>(size); void SetStopGradient(bool stop_gradient = true) {
stop_gradient_.resize(size, false); stop_gradient_ = stop_gradient;
} }
bool IsInitialized() const { return size_ != -1; } void SetTensorMeta(const phi::DenseTensorMeta& meta) {
bool IsStopGradient(size_t rank) const { return stop_gradient_[rank]; } meta_ = std::make_shared<phi::DenseTensorMeta>(meta);
int Size() const { return size_; } }
void SetStopGradient(size_t rank, bool stop_gradient = true) { bool HasTensorMeta() const { return meta_ && meta_.get(); }
stop_gradient_.at(rank) = stop_gradient; const phi::DenseTensorMeta& GetTensorMeta() const {
if (!HasTensorMeta()) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"meta_ of GradSlotMeta has not been initialized yet."
"You're expected to check Edge availability with HasTensorMeta()"
"before calling GetTensorMeta() interface."));
}
return *meta_.get();
} }
private: private:
int size_{-1}; bool stop_gradient_{false};
std::vector<bool> stop_gradient_{false}; std::shared_ptr<phi::DenseTensorMeta> meta_ = nullptr;
}; };
class GradNodeBase { class GradNodeBase {
...@@ -95,8 +102,12 @@ class GradNodeBase { ...@@ -95,8 +102,12 @@ class GradNodeBase {
* is better choice to fit this format. * is better choice to fit this format.
* **/ * **/
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
const std::vector<std::vector<paddle::experimental::Tensor>>& grads) = 0; const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) = 0;
virtual void ClearTensorWrappers() = 0;
virtual bool IsTensorWrappersCleared() = 0;
/** /**
* AddEdges is designed to set input tensors' backward Node as current * AddEdges is designed to set input tensors' backward Node as current
* node's Edges. * node's Edges.
...@@ -108,25 +119,30 @@ class GradNodeBase { ...@@ -108,25 +119,30 @@ class GradNodeBase {
void AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id); void AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id);
void AddEdges(AutogradMeta* meta, size_t slot_id); void AddEdges(AutogradMeta* meta, size_t slot_id);
/** // adj_edges were moved inside OutputMeta(), so no available direct access
* GetEdges is designed to get all edges of current node**/ // from GradNodeBase.
const std::vector<std::vector<Edge>>& GetEdges() const; // To access Edges, get GradSlotMeta by calling OutputMeta(), then use
// slot_meta.GetEdge()
/** /**
* Get Input Meta of current Grad node**/ * Get Input Meta of current Grad node**/
const std::vector<GradSlotMeta>& InputMeta() const; const std::vector<std::vector<GradSlotMeta>>& InputMeta() const;
/** /**
* Get Output Meta of current Grad node**/ * Get Output Meta of current Grad node**/
const std::vector<GradSlotMeta>& OutputMeta() const; const std::vector<std::vector<GradSlotMeta>>& OutputMeta() const;
/** /**
* Set bwd ins and outs info with forward vars * Set bwd ins and outs info with forward vars
* **/ * **/
void SetGradInMeta(std::vector<AutogradMeta*>* fwd_out, size_t slot_rank); void SetGradInMeta(const std::vector<paddle::experimental::Tensor>& fwd_out,
void SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank); size_t slot_rank);
void SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
size_t slot_rank);
void SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in, size_t slot_rank); void SetGradOutMeta(const std::vector<paddle::experimental::Tensor>& fwd_in,
void SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank); size_t slot_rank);
void SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
size_t slot_rank);
/** /**
* Default setters for Grad in/out meta this should be used for same special * Default setters for Grad in/out meta this should be used for same special
...@@ -158,11 +174,21 @@ class GradNodeBase { ...@@ -158,11 +174,21 @@ class GradNodeBase {
std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks( std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors); const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);
/**
* Handle Complex - Real Type Promotion
* **/
void HandleComplexGradToRealGrad(
std::vector<std::vector<paddle::experimental::Tensor>>* out_grads);
bool NeedComplexToRealConversion() { return need_complex_to_real_; }
virtual std::string name() { return "GradNodeBase"; } virtual std::string name() { return "GradNodeBase"; }
private: /**
// TODO(jiabin): Use SmallVector instead after merge PR from develop * GetEdges is designed to get all edges of current node**/
const std::vector<std::vector<Edge>>& GetEdges() const;
private:
// TODO(zhanlve): Merge adj_edges_ into GradOutMeta
// Edges recorded the backward related node info, which indicate all edges // Edges recorded the backward related node info, which indicate all edges
// linked // linked
// by this Grad Node. // by this Grad Node.
...@@ -170,10 +196,10 @@ class GradNodeBase { ...@@ -170,10 +196,10 @@ class GradNodeBase {
std::vector<std::vector<Edge>> adj_edges_; std::vector<std::vector<Edge>> adj_edges_;
// bwd_out_meta_ is used to record Grad output info for backward // bwd_out_meta_ is used to record Grad output info for backward
std::vector<GradSlotMeta> bwd_out_meta_; std::vector<std::vector<GradSlotMeta>> bwd_out_meta_;
// bwd_in_meta_ used to record Grad input info for backward // bwd_in_meta_ used to record Grad input info for backward
std::vector<GradSlotMeta> bwd_in_meta_; std::vector<std::vector<GradSlotMeta>> bwd_in_meta_;
// Gradient Hooks // Gradient Hooks
// Customer may register a list of hooks which will be called in order during // Customer may register a list of hooks which will be called in order during
// backward // backward
...@@ -184,6 +210,8 @@ class GradNodeBase { ...@@ -184,6 +210,8 @@ class GradNodeBase {
/* hook */ std::shared_ptr<TensorHook>>> /* hook */ std::shared_ptr<TensorHook>>>
gradient_hooks_; gradient_hooks_;
// We handle complex to real conversion only if any complex GradIn is involved
bool need_complex_to_real_ = false;
int64_t next_hook_id_{0}; int64_t next_hook_id_{0};
}; };
......
...@@ -21,6 +21,11 @@ ...@@ -21,6 +21,11 @@
namespace egr { namespace egr {
void GradTensorHolder::SetBufferSlotRankZeros(size_t slot_id, size_t rank) {
buffer_[slot_id][rank] =
paddle::experimental::zeros_like(buffer_[slot_id][rank]);
}
void GradTensorHolder::add(size_t slot_id, size_t rank, void GradTensorHolder::add(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t, const paddle::experimental::Tensor& t,
bool fill_one) { bool fill_one) {
......
...@@ -26,12 +26,13 @@ namespace egr { ...@@ -26,12 +26,13 @@ namespace egr {
* GradTensorHolder should have as same format as forward output **/ * GradTensorHolder should have as same format as forward output **/
class GradTensorHolder { class GradTensorHolder {
public: public:
explicit GradTensorHolder(const std::vector<GradSlotMeta>& meta) { explicit GradTensorHolder(
VLOG(7) << "Init GradTensorHolder with meta size: " << meta.size(); const std::vector<std::vector<GradSlotMeta>>& metas) {
buffer_.resize(meta.size()); VLOG(7) << "Init GradTensorHolder with meta size: " << metas.size();
buffer_.resize(metas.size());
for (size_t i = 0; i < buffer_.size(); i++) { for (size_t i = 0; i < buffer_.size(); i++) {
VLOG(7) << "Init GradTensorHolder with meta rank: " << meta[i].Size(); VLOG(7) << "Init GradTensorHolder with meta rank: " << metas[i].size();
buffer_[i].resize(meta[i].Size()); buffer_[i].resize(metas[i].size());
} }
} }
...@@ -56,6 +57,8 @@ class GradTensorHolder { ...@@ -56,6 +57,8 @@ class GradTensorHolder {
return buffer_; return buffer_;
} }
void SetBufferSlotRankZeros(size_t slot_id, size_t rank);
private: private:
std::vector<std::vector<paddle::experimental::Tensor>> buffer_; std::vector<std::vector<paddle::experimental::Tensor>> buffer_;
}; };
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册