Unverified commit 778ea4ec, authored by Chen Weihang, committed by GitHub

[Eager] Polish grad code details (#42536)

* polish grad details

* polish detail by comment
Parent 13bcb7cd
@@ -66,68 +66,69 @@ class GeneralGrad {
"stop_gradient=True.",
msg, i));
if (is_no_grad_vars) {
(no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
(no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
} else { // normal input
(input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
(input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
}
}
}
}
// Purify potential_startup_nodes, remove nodes those are the same as
// Purify potential_startup_nodes_, remove nodes those are the same as
// input_target_nodes
void PurifyPotentialStartUpNodes() {
VLOG(6) << "Running in PurifyPotentialStartUpNodes";
if (input_target_nodes_inputmeta_map.empty()) return;
if (input_target_nodes_inputmeta_map_.empty()) return;
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto startup_op : potential_startup_nodes) {
auto iter = input_target_nodes_inputmeta_map.find(startup_op);
if (iter != input_target_nodes_inputmeta_map.end()) {
for (auto startup_op : potential_startup_nodes_) {
auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
if (iter != input_target_nodes_inputmeta_map_.end()) {
potential_startup_nodes_to_be_erased.emplace(iter->first);
}
}
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto nodes : potential_startup_nodes_to_be_erased) {
potential_startup_nodes.erase(nodes);
potential_startup_nodes_.erase(nodes);
}
}
}
// Remove some nodes those doesn't need to be
// stored in potential_stop_nodes、potential_startup_nodes
// stored in potential_stop_nodes_、potential_startup_nodes_
void UpdateGraphInfo() {
// Updated potential_sotp_nodes by depending_nodes,
// Updated potential_sotp_nodes by depending_nodes_,
// make sure the path from root to target_node is ok
std::unordered_set<GradNodeBase*> _startup_ops;
std::unordered_set<GradNodeBase*> startup_ops;
VLOG(6) << "Running in UpdateGraphInfo";
std::queue<GradNodeBase*> queue;
for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
for (auto& target_nodes_inputmeta_pair :
input_target_nodes_inputmeta_map_) {
queue.emplace(target_nodes_inputmeta_pair.first);
}
while (!queue.empty()) {
auto* target_node = queue.front();
queue.pop();
if (!(depending_nodes)[target_node].empty()) {
auto precedding_nodes = (depending_nodes)[target_node];
if (!(depending_nodes_)[target_node].empty()) {
auto precedding_nodes = (depending_nodes_)[target_node];
for (auto pre_nodes : precedding_nodes) {
queue.emplace(pre_nodes);
if (potential_stop_nodes.find(pre_nodes) !=
potential_stop_nodes.end()) {
potential_stop_nodes.erase(pre_nodes);
if (potential_stop_nodes_.find(pre_nodes) !=
potential_stop_nodes_.end()) {
potential_stop_nodes_.erase(pre_nodes);
}
}
} else { // startup_ops have no precedding nodes
VLOG(6) << "Emplace _startup_ops";
_startup_ops.emplace(target_node);
VLOG(6) << "Emplace startup_ops";
startup_ops.emplace(target_node);
}
}
// Purify potential_startup_nodes again, remove some
// Purify potential_startup_nodes_ again, remove some
// potential startup_nodes that unreach to input target nodes
if (!_startup_ops.empty()) {
if (!startup_ops.empty()) {
std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
for (auto node : potential_startup_nodes) {
if (_startup_ops.count(node) == 0) {
for (auto node : potential_startup_nodes_) {
if (startup_ops.count(node) == 0) {
VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
potential_startup_nodes_to_be_erased.emplace(node);
}
@@ -135,14 +136,14 @@ class GeneralGrad {
if (!potential_startup_nodes_to_be_erased.empty()) {
for (auto node : potential_startup_nodes_to_be_erased) {
VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
potential_startup_nodes.erase(node);
potential_startup_nodes_.erase(node);
}
}
}
}
// Get Graph Info Betweent input target GradNode and outputs,
// record depending_nodes、potential_stop_nodes、potential_startup_nodes
// record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
@@ -164,9 +165,9 @@ class GeneralGrad {
visited.insert(node);
// Check node is target_nodes or not, if node is not target_node,
// all the next_node will be marked in potential_stop_nodes
// all the next_node will be marked in potential_stop_nodes_
bool is_potential_stop_nodes =
input_target_nodes_inputmeta_map.count(node);
input_target_nodes_inputmeta_map_.count(node);
// Find and append next nodes
const paddle::small_vector<std::vector<GradSlotMeta>,
@@ -186,40 +187,41 @@ class GeneralGrad {
// all the next_nodes of current node will be inserted to
// potential_stop_node
if (is_potential_stop_nodes) {
potential_stop_nodes.emplace(next_node);
potential_stop_nodes_.emplace(next_node);
}
// Update in_degree
if (!node_in_degree_map.count(next_node))
if (!node_in_degree_map.count(next_node)) {
node_in_degree_map[next_node] = 0;
}
node_in_degree_map[next_node]++;
// Record depending relationship
(depending_nodes)[next_node].emplace(node);
(depending_nodes_)[next_node].emplace(node);
queue.push(next_node);
}
}
}
// Update Graph Info, remove some nodes in
// potential_stop_nodes、potential_startup_nodes
// potential_stop_nodes_、potential_startup_nodes_
UpdateGraphInfo();
}
void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
std::queue<GradNodeBase*> tmp_queue;
for (auto nodes : potential_startup_nodes) {
for (auto nodes : potential_startup_nodes_) {
tmp_queue.emplace(nodes);
}
tmp_queue.swap(*queue);
}
// Set result for input target grad_var when potential_startup_nodes is empty
// Set result for input target grad_var when potential_startup_nodes_ is empty
void SetResultForInputTargetVar(
const std::unordered_map<GradNodeBase*,
std::unique_ptr<GradTensorHolder>>&
node_input_buffers_dict) {
if (potential_startup_nodes.size() == 0) {
for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
if (potential_startup_nodes_.size() == 0) {
for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
// out rank_info of forward op
auto rank_info = input_target_node.second->OutRankInfo();
auto iter = node_input_buffers_dict.find(input_target_node.first);
@@ -227,7 +229,7 @@ class GeneralGrad {
auto& target_result =
(iter->second)->Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[input_target_node.first] = target_result;
results_map_[input_target_node.first] = target_result;
}
}
}
@@ -236,8 +238,8 @@ class GeneralGrad {
// Set input target grad_var from node_input_buffer by inputmeta
void SetResultForInputTargetVar(GradTensorHolder input_buffers,
GradNodeBase* node) {
auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
auto iter = GetInputTargetNodesInputMetaMap()->find(node);
if (iter != GetInputTargetNodesInputMetaMap()->end()) {
VLOG(6) << "Get target result by by inputmeta";
// out rank_info of forward op
auto rank_info = (iter->second)->OutRankInfo();
@@ -245,7 +247,7 @@ class GeneralGrad {
auto& target_result =
input_buffers.Buffers()[rank_info.first][rank_info.second];
// save the target result
results_map[node] = target_result;
results_map_[node] = target_result;
}
}
@@ -271,8 +273,8 @@ class GeneralGrad {
"input";
}
auto iter = results_map.find(target_node);
if (iter != results_map.end()) {
auto iter = results_map_.find(target_node);
if (iter != results_map_.end()) {
// set StopGradient = !create_graph
AutogradMeta* tensor_auto_grad_meta =
EagerUtils::autograd_meta(&(iter->second));
@@ -303,12 +305,12 @@ class GeneralGrad {
GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
// Get inputs's GradNodes and InputMeta Info
GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
// Purify potential_startup_ops, remove those nodes that are the same as
// Purify potentialstartup_ops, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes();
// Get Graph Info Betweent input target gradnode and outputs
// Record the depending_nodes and
// potential_stop_nodes、potential_startup_nodes
// Record the depending_nodes_ and
// potential_stop_nodes_、potential_startup_nodes_
GetGraphInfoBetweenTargets(*queue);
// Reset queue. Queue is empty only when
// 1.input equals to output. 2.input can not reach to output.
@@ -318,34 +320,34 @@ class GeneralGrad {
}
bool IsPotentialStopNodes(GradNodeBase* node) {
return potential_stop_nodes.count(node);
return potential_stop_nodes_.count(node);
}
std::unordered_map<GradNodeBase*, AutogradMeta*>*
GetNoGradVarNodesInputMetaMap() {
return &no_grad_var_nodes_inputmeta_map;
return &no_grad_var_nodes_inputmeta_map_;
}
std::unordered_map<GradNodeBase*, AutogradMeta*>*
GetInPutTargetNodesInputMetaMap() {
return &input_target_nodes_inputmeta_map;
GetInputTargetNodesInputMetaMap() {
return &input_target_nodes_inputmeta_map_;
}
std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
return &potential_stop_nodes;
return &potential_stop_nodes_;
}
std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
return &potential_startup_nodes;
return &potential_startup_nodes_;
}
void Clear() {
no_grad_var_nodes_inputmeta_map.clear();
input_target_nodes_inputmeta_map.clear();
potential_startup_nodes.clear();
potential_stop_nodes.clear();
depending_nodes.clear();
results_map.clear();
no_grad_var_nodes_inputmeta_map_.clear();
input_target_nodes_inputmeta_map_.clear();
potential_startup_nodes_.clear();
potential_stop_nodes_.clear();
depending_nodes_.clear();
results_map_.clear();
copied_grad_nodes_.clear();
orig_to_copied_node_mapping_.clear();
}
@@ -426,18 +428,18 @@ class GeneralGrad {
static GeneralGrad* general_grad_;
// no_grad_vars's GradNode and GradNode's InputMeta.
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
no_grad_var_nodes_inputmeta_map;
no_grad_var_nodes_inputmeta_map_;
// inputs's GradNode and GradNode's InputMeta.
std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
input_target_nodes_inputmeta_map;
input_target_nodes_inputmeta_map_;
// Record all the potential startup_nodes, will be changed.
std::unordered_set<GradNodeBase*> potential_startup_nodes;
std::unordered_set<GradNodeBase*> potential_startup_nodes_;
// Record all the potential stop nodes, will be changed.
std::unordered_set<GradNodeBase*> potential_stop_nodes;
std::unordered_set<GradNodeBase*> potential_stop_nodes_;
std::unordered_map<GradNodeBase* /* next node */,
std::unordered_set<GradNodeBase*> /* pre nodes */>
depending_nodes;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
depending_nodes_;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
@@ -619,7 +621,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// GradTensorHolder will initialize another tensor with same tensortype,
// datatype and dims but filled with 1.0
node_input_buffers_dict[grad_node]->CopyValueFromTensor(
input_info.first, input_info.second, tensor, true /*fill_one=true*/);
input_info.first, input_info.second, tensor, /*fill_one=*/true);
}
// Prepare queue, potential startup_nodes
@@ -657,7 +659,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
VLOG(6) << "Running GradNode:" << node->name();
paddle::platform::RecordEvent node_record_event(
std::string((*node).name()) + " grad_node",
std::string((*node).name()),
paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
@@ -667,14 +669,15 @@ std::vector<paddle::experimental::Tensor> RunBackward(
queue.pop();
// Run node: This is where Hook happens
PADDLE_ENFORCE(
node_input_buffers_dict.count(node),
auto node_input_buffer_iter = node_input_buffers_dict.find(node);
PADDLE_ENFORCE_NE(
node_input_buffer_iter, node_input_buffers_dict.end(),
paddle::platform::errors::Fatal(
"Unable to find next node in the GradTensorHolder \n"
"Trying to run Node without configuring its GradTensorHolder."));
std::unique_ptr<GradTensorHolder> node_input_buffer =
std::move(node_input_buffers_dict[node]);
std::move(node_input_buffer_iter->second);
// Set input target grad_var from node_input_buffer by inputmeta
if (!inputs.empty() && is_general_grad) {
@@ -715,8 +718,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
// TODO(jiabin): Should we erase it or find a more efficient way.
node_input_buffers_dict.erase(node);
node_input_buffers_dict.erase(node_input_buffer_iter);
// Prepare GradTensorHolder for next node
const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
@@ -736,8 +738,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
auto edge_rank = edge.GetEdgeRankInfo();
// Since we make edge has as same rank as bwd outputs, we indexing them
// with
// the same rank(i, j)
// with the same rank(i, j)
auto next_node_shared = edge.GetMutableGradNode();
// Next node could be nullptr if it is leaf tensor with no
......
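The GeneralGrad bookkeeping above (UpdateGraphInfo and the related purify passes) is essentially a reverse walk over the recorded depending_nodes_ map: starting from the input target nodes, every reachable predecessor is taken off the stop-node candidates, nodes with no predecessors become the real startup ops, and potential_startup_nodes_ is intersected with those. A simplified standalone sketch of that pruning pass, using a plain int id as a hypothetical stand-in for GradNodeBase* (not Paddle code):

#include <queue>
#include <unordered_map>
#include <unordered_set>

using Node = int;  // hypothetical stand-in for GradNodeBase*

// Walk from the input target nodes back toward the roots along the recorded
// predecessor ("depending") edges. Every predecessor found on such a path is
// removed from the stop-node candidates; nodes with no predecessors are the
// true startup ops, and the startup candidates are intersected with them.
void PruneStopAndStartupNodes(
    const std::unordered_set<Node>& input_target_nodes,
    const std::unordered_map<Node, std::unordered_set<Node>>& depending_nodes,
    std::unordered_set<Node>* potential_stop_nodes,
    std::unordered_set<Node>* potential_startup_nodes) {
  std::unordered_set<Node> startup_ops;
  std::unordered_set<Node> visited;  // avoid revisiting shared predecessors
  std::queue<Node> queue;
  for (Node n : input_target_nodes) queue.push(n);

  while (!queue.empty()) {
    Node node = queue.front();
    queue.pop();
    if (!visited.insert(node).second) continue;
    auto iter = depending_nodes.find(node);
    if (iter != depending_nodes.end() && !iter->second.empty()) {
      for (Node pre : iter->second) {
        queue.push(pre);
        potential_stop_nodes->erase(pre);  // pre lies on a live path to an input
      }
    } else {
      startup_ops.insert(node);  // no predecessors: a real startup op
    }
  }

  // Drop startup candidates that cannot reach any input target node.
  for (auto it = potential_startup_nodes->begin();
       it != potential_startup_nodes->end();) {
    if (startup_ops.count(*it) == 0) {
      it = potential_startup_nodes->erase(it);
    } else {
      ++it;
    }
  }
}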
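The RunBackward hunk above also swaps PADDLE_ENFORCE on node_input_buffers_dict.count(node) plus a later operator[] and erase(node) for a single find(), so the map is hashed once and the same iterator drives the check, the move, and the erase. A minimal sketch of that lookup pattern, with hypothetical names and a plain exception standing in for PADDLE_ENFORCE_NE:

#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for the node -> GradTensorHolder map.
using BufferMap = std::unordered_map<int, std::unique_ptr<std::string>>;

std::unique_ptr<std::string> TakeBuffer(BufferMap* buffers, int node_id) {
  // One hash lookup instead of count() + operator[] + erase(key).
  auto iter = buffers->find(node_id);
  if (iter == buffers->end()) {
    throw std::runtime_error("node has no configured input buffer");
  }
  std::unique_ptr<std::string> buffer = std::move(iter->second);
  buffers->erase(iter);  // erase by iterator reuses the lookup above
  return buffer;
}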
@@ -36,6 +36,31 @@
**/
namespace egr {
static void CheckTensor(const paddle::experimental::Tensor& pre,
const paddle::experimental::Tensor& post) {
if (!pre.initialized() && post.initialized()) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"The tensor in before and after hook are not consistent"));
}
if (pre.initialized() && post.initialized()) {
VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
<< paddle::framework::DataType2String(post.dtype());
PADDLE_ENFORCE_EQ(
pre.dtype(), post.dtype(),
paddle::platform::errors::PermissionDenied(
"The dtype of tensor before(%s) and after(%s) hook are not "
"consistent",
paddle::framework::DataType2String(pre.dtype()),
paddle::framework::DataType2String(post.dtype())));
PADDLE_ENFORCE_EQ(
pre.place(), post.place(),
paddle::platform::errors::PermissionDenied(
"The place of tensor before(%s) and after(%s) "
"hook are not consistent",
pre.place().DebugString(), post.place().DebugString()));
}
}
GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
VLOG(6) << "Construct GradNodeBase";
bwd_in_meta_.resize(bwd_in_slot_num);
@@ -271,7 +296,7 @@ void GradNodeBase::SetGradOutMeta(
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
PADDLE_ENFORCE_NE(dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
PADDLE_ENFORCE_NE(dense_tensor->dtype(), phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal(
"Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
......
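The static CheckTensor moved into this file above guards tensor hooks: a hook may replace a gradient tensor, but it must not silently change its dtype or place. A simplified standalone sketch of that pre/post-hook consistency check, using a hypothetical plain struct instead of paddle::experimental::Tensor:

#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the pieces of Tensor that CheckTensor compares.
struct FakeTensor {
  bool initialized = false;
  std::string dtype;  // e.g. "float32"
  std::string place;  // e.g. "Place(gpu:0)"
};

// Run a hook and verify it did not alter the tensor's metadata.
FakeTensor RunCheckedHook(
    const FakeTensor& pre,
    const std::function<FakeTensor(const FakeTensor&)>& hook) {
  FakeTensor post = hook(pre);
  if (!pre.initialized && post.initialized) {
    throw std::runtime_error(
        "the tensor before and after the hook are not consistent");
  }
  if (pre.initialized && post.initialized) {
    if (pre.dtype != post.dtype) {
      throw std::runtime_error("hook changed dtype: " + pre.dtype + " -> " +
                               post.dtype);
    }
    if (pre.place != post.place) {
      throw std::runtime_error("hook changed place: " + pre.place + " -> " +
                               post.place);
    }
  }
  return post;
}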
@@ -30,32 +30,23 @@ namespace egr {
* The GradNodeBase will be held in autograd_meta, and it is also a member of
* Edge, which indicates the edge of backward graph.
*
* TODO:(yangzhanlue) GradNodeBase will also in charge of get the correct input
* TODO(yangzhanlue): GradNodeBase will also in charge of get the correct input
* from GradOpDescMaker to GradNodeBase.
*
* NOTE:GradNodeBase has a method named run, this method should be overrided by
* the
* specific derived class, it will prepare backward inputs and double backward's
* depends. Then, it will call C++ API of backward kernel functions to finish
* backward computation.
* NOTE: GradNodeBase has a method named run, this method should be overrided by
* the specific derived class, it will prepare backward inputs and double
* backward's depends. Then, it will call C++ API of backward kernel functions
* to finish backward computation.
*
* NOTE:GradNodeBase holds its own inputs and Outputs
* NOTE: GradNodeBase holds its own inputs and Outputs
*
* Edge is defined to descripe depend of backward, an Edge is what linked
* between two
* node, it should contain a Node and rank of this Node (this is used to
* indicate which
* input of grad this edge belong).
* */
* between two node, it should contain a Node and rank of this Node (this is
* used to indicate which input of grad this edge belong).
**/
class AutogradMeta;
class GradNodeBase;
/**
* GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
* has lots of operators
* whose backward logic is depends on if it has some specific inputs or outputs.
* So, we need a meta info
* to record it's needs.
* **/
class Edge {
public:
// Default constructor for Edges in order to construct it for AutogradMeta
@@ -64,8 +55,7 @@ class Edge {
// In real use cases we should create Edge from grad node and input rank which
// indicate which edge it is.
// Since we have slot design in operators we will have to locate an edge with
// slot
// and rank.
// slot and rank.
Edge(const std::shared_ptr<GradNodeBase>& grad_node, size_t in_slot_id,
size_t in_rank)
: in_slot_id_(in_slot_id), in_rank_(in_rank), grad_node_(grad_node) {}
@@ -120,6 +110,12 @@ class Edge {
size_t in_rank_;
std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
/**
* GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
* has lots of operators whose backward logic is depends on if it has some
* specific inputs or outputs. So, we need a meta info to record it's needs.
**/
class GradSlotMeta {
public:
GradSlotMeta() = default;
@@ -171,16 +167,13 @@ class GradNodeBase {
/**
* operator() designed to contian the real backward execution logic, it should
* be
* overrided by derived class defined for each operator. It accepts a vector
* of
* Tensor which contains grads input of current operator
* be overrided by derived class defined for each operator. It accepts a
* vector of Tensor which contains grads input of current operator
*
* Note: why we need backward inputs and outputs construct as vector of vector
* of paddle::experimental::Tensor?
* Since all of paddle op composite in form of {"Slot name ", vector<Var>},
* so, vector of vector
* is better choice to fit this format.
* so, vector of vector is better choice to fit this format.
* **/
virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
kSlotSmallVectorSize>
@@ -294,36 +287,12 @@ class GradNodeBase {
/* slot id */ size_t, /* rank */ size_t,
/* hook */ std::shared_ptr<TensorHook>>>
gradient_hooks_;
int64_t next_hook_id_{0};
// We handle complex to real conversion only if any complex GradIn is involved
bool need_complex_to_real_ = false;
int64_t next_hook_id_{0};
bool is_tensor_wrappers_cleared_ = false;
};
inline void CheckTensor(const paddle::experimental::Tensor& pre,
const paddle::experimental::Tensor& post) {
if (!pre.initialized() && post.initialized()) {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"The tensor in before and after hook are not consistent"));
}
if (pre.initialized() && post.initialized()) {
VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
<< paddle::framework::DataType2String(post.dtype());
PADDLE_ENFORCE_EQ(
pre.dtype(), post.dtype(),
paddle::platform::errors::PermissionDenied(
"The dtype of tensor before(%s) and after(%s) hook are not "
"consistent",
paddle::framework::DataType2String(pre.dtype()),
paddle::framework::DataType2String(post.dtype())));
PADDLE_ENFORCE_EQ(
pre.place(), post.place(),
paddle::platform::errors::PermissionDenied(
"The place of tensor before(%s) and after(%s) "
"hook are not consistent",
pre.place().DebugString(), post.place().DebugString()));
}
}
} // namespace egr
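As the reflowed comments above describe, an Edge is essentially a (grad_node, slot, rank) triple: the backward node it points to plus the slot/rank coordinates that say exactly which backward input it feeds. A minimal sketch of that addressing scheme with hypothetical types (not the real Edge/GradNodeBase API):

#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical backward node: one buffer per (slot, rank) coordinate,
// mirroring the slot design described in the comments above.
struct FakeGradNode {
  std::vector<std::vector<float>> input_buffers;
};

struct FakeEdge {
  std::shared_ptr<FakeGradNode> grad_node;
  std::size_t in_slot_id = 0;
  std::size_t in_rank = 0;

  // Accumulate a gradient into the exact backward input this edge addresses.
  void Accumulate(float grad) const {
    grad_node->input_buffers[in_slot_id][in_rank] += grad;
  }
};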
@@ -88,6 +88,7 @@ class TensorWrapper {
} else {
intermidiate_tensor_.set_impl(tensor.impl());
}
// TODO(jiabin): This may has server performance issue
intermidiate_tensor_.set_name(tensor.name() + "@Saved");
@@ -118,24 +119,25 @@ class TensorWrapper {
paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
if (new_grad_node) {
VLOG(3) << "Recovered TensorWrapper with GradNode "
<< new_grad_node->name() << " addr: " << new_grad_node.get();
} else {
VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
}
auto* intermediate_autograd_meta =
EagerUtils::unsafe_autograd_meta(intermidiate_tensor_);
auto p_ab_autograd_meta =
std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
if (new_grad_node) {
VLOG(3) << "Recovered TensorWrapper with GradNode "
<< new_grad_node->name() << " addr: " << new_grad_node.get();
p_ab_autograd_meta->SetGradNode(new_grad_node);
} else {
VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
}
recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
return recovered_tensor;
}
}
void clear() { intermidiate_tensor_.reset(); }
private:
void check_inplace_version() {
if (no_need_buffer_) {
VLOG(6) << "There's no need to check inplace_version because "
@@ -170,8 +172,6 @@ class TensorWrapper {
}
}
void clear() { intermidiate_tensor_.reset(); }
private:
bool full_reserved_ = false;
bool no_need_buffer_ = false;
......
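The recover() hunk above keeps a single construction of the copied AutogradMeta and folds the two checks on the locked weak grad node (the logging and the SetGradNode call) into one branch. A stripped-down sketch of that weak_ptr pattern, with hypothetical types in place of AutogradMeta and GradNodeBase:

#include <memory>
#include <utility>

struct FakeGradNode {};  // hypothetical stand-in for GradNodeBase

struct FakeAutogradMeta {
  std::shared_ptr<FakeGradNode> grad_node;
};

// Copy the cached meta once, then re-attach the grad node only if the
// weakly-held node is still alive.
FakeAutogradMeta RecoverMeta(const FakeAutogradMeta& cached,
                             const std::weak_ptr<FakeGradNode>& weak_node) {
  FakeAutogradMeta recovered = cached;  // single construction for both cases
  if (auto node = weak_node.lock()) {
    recovered.grad_node = std::move(node);
  }
  return recovered;
}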