Unverified commit 778ea4ec, authored by Chen Weihang and committed by GitHub

[Eager] Polish grad code details (#42536)

* polish grad details

* polish detail by comment
Parent 13bcb7cd
@@ -66,68 +66,69 @@ class GeneralGrad {
                 "stop_gradient=True.",
                 msg, i));
         if (is_no_grad_vars) {
-          (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
+          (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
         } else {  // normal input
-          (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
+          (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
         }
       }
     }
   }
-  // Purify potential_startup_nodes, remove nodes those are the same as
+  // Purify potential_startup_nodes_, remove nodes those are the same as
   // input_target_nodes
   void PurifyPotentialStartUpNodes() {
     VLOG(6) << "Running in PurifyPotentialStartUpNodes";
-    if (input_target_nodes_inputmeta_map.empty()) return;
+    if (input_target_nodes_inputmeta_map_.empty()) return;
     std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
-    for (auto startup_op : potential_startup_nodes) {
-      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
-      if (iter != input_target_nodes_inputmeta_map.end()) {
+    for (auto startup_op : potential_startup_nodes_) {
+      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
+      if (iter != input_target_nodes_inputmeta_map_.end()) {
         potential_startup_nodes_to_be_erased.emplace(iter->first);
       }
     }
     if (!potential_startup_nodes_to_be_erased.empty()) {
       for (auto nodes : potential_startup_nodes_to_be_erased) {
-        potential_startup_nodes.erase(nodes);
+        potential_startup_nodes_.erase(nodes);
       }
     }
   }
   // Remove some nodes those doesn't need to be
-  // stored in potential_stop_nodes、potential_startup_nodes
+  // stored in potential_stop_nodes_、potential_startup_nodes_
   void UpdateGraphInfo() {
-    // Updated potential_sotp_nodes by depending_nodes,
+    // Updated potential_sotp_nodes by depending_nodes_,
     // make sure the path from root to target_node is ok
-    std::unordered_set<GradNodeBase*> _startup_ops;
+    std::unordered_set<GradNodeBase*> startup_ops;
     VLOG(6) << "Running in UpdateGraphInfo";
     std::queue<GradNodeBase*> queue;
-    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
+    for (auto& target_nodes_inputmeta_pair :
+         input_target_nodes_inputmeta_map_) {
       queue.emplace(target_nodes_inputmeta_pair.first);
     }
     while (!queue.empty()) {
       auto* target_node = queue.front();
       queue.pop();
-      if (!(depending_nodes)[target_node].empty()) {
-        auto precedding_nodes = (depending_nodes)[target_node];
+      if (!(depending_nodes_)[target_node].empty()) {
+        auto precedding_nodes = (depending_nodes_)[target_node];
         for (auto pre_nodes : precedding_nodes) {
           queue.emplace(pre_nodes);
-          if (potential_stop_nodes.find(pre_nodes) !=
-              potential_stop_nodes.end()) {
-            potential_stop_nodes.erase(pre_nodes);
+          if (potential_stop_nodes_.find(pre_nodes) !=
+              potential_stop_nodes_.end()) {
+            potential_stop_nodes_.erase(pre_nodes);
           }
         }
       } else {  // startup_ops have no precedding nodes
-        VLOG(6) << "Emplace _startup_ops";
-        _startup_ops.emplace(target_node);
+        VLOG(6) << "Emplace startup_ops";
+        startup_ops.emplace(target_node);
       }
     }
-    // Purify potential_startup_nodes again, remove some
+    // Purify potential_startup_nodes_ again, remove some
     // potential startup_nodes that unreach to input target nodes
-    if (!_startup_ops.empty()) {
+    if (!startup_ops.empty()) {
       std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
-      for (auto node : potential_startup_nodes) {
-        if (_startup_ops.count(node) == 0) {
+      for (auto node : potential_startup_nodes_) {
+        if (startup_ops.count(node) == 0) {
           VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
           potential_startup_nodes_to_be_erased.emplace(node);
         }
@@ -135,14 +136,14 @@ class GeneralGrad {
       if (!potential_startup_nodes_to_be_erased.empty()) {
         for (auto node : potential_startup_nodes_to_be_erased) {
           VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
-          potential_startup_nodes.erase(node);
+          potential_startup_nodes_.erase(node);
         }
       }
     }
   }
   // Get Graph Info Betweent input target GradNode and outputs,
-  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
+  // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
   void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
     VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
@@ -164,9 +165,9 @@ class GeneralGrad {
       visited.insert(node);
       // Check node is target_nodes or not, if node is not target_node,
-      // all the next_node will be marked in potential_stop_nodes
+      // all the next_node will be marked in potential_stop_nodes_
       bool is_potential_stop_nodes =
-          input_target_nodes_inputmeta_map.count(node);
+          input_target_nodes_inputmeta_map_.count(node);
       // Find and append next nodes
       const paddle::small_vector<std::vector<GradSlotMeta>,
@@ -186,40 +187,41 @@ class GeneralGrad {
           // all the next_nodes of current node will be inserted to
           // potential_stop_node
           if (is_potential_stop_nodes) {
-            potential_stop_nodes.emplace(next_node);
+            potential_stop_nodes_.emplace(next_node);
           }
           // Update in_degree
-          if (!node_in_degree_map.count(next_node))
+          if (!node_in_degree_map.count(next_node)) {
             node_in_degree_map[next_node] = 0;
+          }
           node_in_degree_map[next_node]++;
           // Record depending relationship
-          (depending_nodes)[next_node].emplace(node);
+          (depending_nodes_)[next_node].emplace(node);
           queue.push(next_node);
         }
       }
     }
     // Update Graph Info, remove some nodes in
-    // potential_stop_nodes、potential_startup_nodes
+    // potential_stop_nodes_、potential_startup_nodes_
     UpdateGraphInfo();
   }
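UpdateGraphInfo above walks backwards from every input target node through depending_nodes_: each predecessor reached on the way is dropped from potential_stop_nodes_, because it still lies on a path to a target, and nodes with no predecessors become the real startup set. A minimal sketch of that reverse traversal, assuming a placeholder Node type and plain std containers rather than the class members used here:

#include <queue>
#include <unordered_map>
#include <unordered_set>

struct Node {};  // stand-in for GradNodeBase

// Walk from the input target nodes back towards the roots of the (acyclic)
// graph. Every predecessor found on the way is removed from the "potential
// stop" set; nodes with no predecessors are collected as startup nodes.
void PruneStopNodes(
    const std::unordered_set<Node*>& input_target_nodes,
    const std::unordered_map<Node*, std::unordered_set<Node*>>& depending_nodes,
    std::unordered_set<Node*>* potential_stop_nodes,
    std::unordered_set<Node*>* startup_nodes) {
  std::queue<Node*> queue;
  for (Node* target : input_target_nodes) queue.push(target);
  while (!queue.empty()) {
    Node* node = queue.front();
    queue.pop();
    auto it = depending_nodes.find(node);
    if (it != depending_nodes.end() && !it->second.empty()) {
      for (Node* pre : it->second) {
        queue.push(pre);
        potential_stop_nodes->erase(pre);  // pre can still reach a target
      }
    } else {
      startup_nodes->insert(node);  // no predecessors: a true startup node
    }
  }
}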
   void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
     std::queue<GradNodeBase*> tmp_queue;
-    for (auto nodes : potential_startup_nodes) {
+    for (auto nodes : potential_startup_nodes_) {
       tmp_queue.emplace(nodes);
     }
     tmp_queue.swap(*queue);
   }
-  // Set result for input target grad_var when potential_startup_nodes is empty
+  // Set result for input target grad_var when potential_startup_nodes_ is empty
   void SetResultForInputTargetVar(
       const std::unordered_map<GradNodeBase*,
                                std::unique_ptr<GradTensorHolder>>&
           node_input_buffers_dict) {
-    if (potential_startup_nodes.size() == 0) {
-      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
+    if (potential_startup_nodes_.size() == 0) {
+      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
         // out rank_info of forward op
         auto rank_info = input_target_node.second->OutRankInfo();
         auto iter = node_input_buffers_dict.find(input_target_node.first);
@@ -227,7 +229,7 @@ class GeneralGrad {
           auto& target_result =
               (iter->second)->Buffers()[rank_info.first][rank_info.second];
           // save the target result
-          results_map[input_target_node.first] = target_result;
+          results_map_[input_target_node.first] = target_result;
         }
       }
     }
@@ -236,8 +238,8 @@ class GeneralGrad {
   // Set input target grad_var from node_input_buffer by inputmeta
   void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                   GradNodeBase* node) {
-    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
-    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
+    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
+    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
       VLOG(6) << "Get target result by by inputmeta";
       // out rank_info of forward op
       auto rank_info = (iter->second)->OutRankInfo();
@@ -245,7 +247,7 @@ class GeneralGrad {
       auto& target_result =
           input_buffers.Buffers()[rank_info.first][rank_info.second];
       // save the target result
-      results_map[node] = target_result;
+      results_map_[node] = target_result;
     }
   }
@@ -271,8 +273,8 @@ class GeneralGrad {
               "input";
       }
-      auto iter = results_map.find(target_node);
-      if (iter != results_map.end()) {
+      auto iter = results_map_.find(target_node);
+      if (iter != results_map_.end()) {
         // set StopGradient = !create_graph
         AutogradMeta* tensor_auto_grad_meta =
             EagerUtils::autograd_meta(&(iter->second));
@@ -303,12 +305,12 @@ class GeneralGrad {
     GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
     // Get inputs's GradNodes and InputMeta Info
     GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
-    // Purify potential_startup_ops, remove those nodes that are the same as
+    // Purify potentialstartup_ops, remove those nodes that are the same as
     // input_target_nodes
     PurifyPotentialStartUpNodes();
     // Get Graph Info Betweent input target gradnode and outputs
-    // Record the depending_nodes and
-    // potential_stop_nodes、potential_startup_nodes
+    // Record the depending_nodes_ and
+    // potential_stop_nodes_、potential_startup_nodes_
     GetGraphInfoBetweenTargets(*queue);
     // Reset queue. Queue is empty only when
     // 1.input equals to output. 2.input can not reach to output.
@@ -318,34 +320,34 @@ class GeneralGrad {
   }
   bool IsPotentialStopNodes(GradNodeBase* node) {
-    return potential_stop_nodes.count(node);
+    return potential_stop_nodes_.count(node);
   }
   std::unordered_map<GradNodeBase*, AutogradMeta*>*
   GetNoGradVarNodesInputMetaMap() {
-    return &no_grad_var_nodes_inputmeta_map;
+    return &no_grad_var_nodes_inputmeta_map_;
   }
   std::unordered_map<GradNodeBase*, AutogradMeta*>*
-  GetInPutTargetNodesInputMetaMap() {
-    return &input_target_nodes_inputmeta_map;
+  GetInputTargetNodesInputMetaMap() {
+    return &input_target_nodes_inputmeta_map_;
   }
   std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
-    return &potential_stop_nodes;
+    return &potential_stop_nodes_;
   }
   std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
-    return &potential_startup_nodes;
+    return &potential_startup_nodes_;
   }
   void Clear() {
-    no_grad_var_nodes_inputmeta_map.clear();
-    input_target_nodes_inputmeta_map.clear();
-    potential_startup_nodes.clear();
-    potential_stop_nodes.clear();
-    depending_nodes.clear();
-    results_map.clear();
+    no_grad_var_nodes_inputmeta_map_.clear();
+    input_target_nodes_inputmeta_map_.clear();
+    potential_startup_nodes_.clear();
+    potential_stop_nodes_.clear();
+    depending_nodes_.clear();
+    results_map_.clear();
     copied_grad_nodes_.clear();
     orig_to_copied_node_mapping_.clear();
   }
@@ -426,18 +428,18 @@ class GeneralGrad {
   static GeneralGrad* general_grad_;
   // no_grad_vars's GradNode and GradNode's InputMeta.
   std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
-      no_grad_var_nodes_inputmeta_map;
+      no_grad_var_nodes_inputmeta_map_;
   // inputs's GradNode and GradNode's InputMeta.
   std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
-      input_target_nodes_inputmeta_map;
+      input_target_nodes_inputmeta_map_;
   // Record all the potential startup_nodes, will be changed.
-  std::unordered_set<GradNodeBase*> potential_startup_nodes;
+  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
   // Record all the potential stop nodes, will be changed.
-  std::unordered_set<GradNodeBase*> potential_stop_nodes;
+  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
   std::unordered_map<GradNodeBase* /* next node */,
                      std::unordered_set<GradNodeBase*> /* pre nodes */>
-      depending_nodes;
-  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
+      depending_nodes_;
+  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
   std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
   std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
@@ -619,7 +621,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     // GradTensorHolder will initialize another tensor with same tensortype,
     // datatype and dims but filled with 1.0
     node_input_buffers_dict[grad_node]->CopyValueFromTensor(
-        input_info.first, input_info.second, tensor, true /*fill_one=true*/);
+        input_info.first, input_info.second, tensor, /*fill_one=*/true);
   }
   // Prepare queue, potential startup_nodes
@@ -657,7 +659,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     VLOG(6) << "Running GradNode:" << node->name();
     paddle::platform::RecordEvent node_record_event(
-        std::string((*node).name()) + " grad_node",
+        std::string((*node).name()),
         paddle::platform::TracerEventType::Operator, 1);
     if (queue.size() > 1 && node_in_degree_map[node] != 0) {
@@ -667,14 +669,15 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     queue.pop();
     // Run node: This is where Hook happens
-    PADDLE_ENFORCE(
-        node_input_buffers_dict.count(node),
+    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
+    PADDLE_ENFORCE_NE(
+        node_input_buffer_iter, node_input_buffers_dict.end(),
         paddle::platform::errors::Fatal(
             "Unable to find next node in the GradTensorHolder \n"
            "Trying to run Node without configuring its GradTensorHolder."));
     std::unique_ptr<GradTensorHolder> node_input_buffer =
-        std::move(node_input_buffers_dict[node]);
+        std::move(node_input_buffer_iter->second);
     // Set input target grad_var from node_input_buffer by inputmeta
     if (!inputs.empty() && is_general_grad) {
@@ -715,8 +718,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     }
     // TODO(jiabin): Should we erase it or find a more efficient way.
-
-    node_input_buffers_dict.erase(node);
+    node_input_buffers_dict.erase(node_input_buffer_iter);
     // Prepare GradTensorHolder for next node
     const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
@@ -736,8 +738,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
       }
       auto edge_rank = edge.GetEdgeRankInfo();
       // Since we make edge has as same rank as bwd outputs, we indexing them
-      // with
-      // the same rank(i, j)
+      // with the same rank(i, j)
       auto next_node_shared = edge.GetMutableGradNode();
       // Next node could be nullptr if it is leaf tensor with no
...
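The RunBackward hunks above also replace a count() check followed by operator[] with a single find(), so node_input_buffers_dict is probed once and the same iterator serves the existence check, the move of the GradTensorHolder, and the later erase. A generic sketch of that single-lookup pattern with stand-in types, not the actual RunBackward code:

#include <memory>
#include <stdexcept>
#include <unordered_map>

struct Buffer {};  // stand-in for GradTensorHolder
struct Node {};    // stand-in for GradNodeBase

std::unique_ptr<Buffer> TakeBuffer(
    std::unordered_map<Node*, std::unique_ptr<Buffer>>* buffers, Node* node) {
  // One lookup instead of count() + operator[]: the iterator is reused for
  // the existence check, for moving the value out, and for the erase.
  auto iter = buffers->find(node);
  if (iter == buffers->end()) {
    throw std::runtime_error("Node has no configured buffer");
  }
  std::unique_ptr<Buffer> taken = std::move(iter->second);
  buffers->erase(iter);  // erase by iterator, as the patched code does
  return taken;
}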
@@ -36,6 +36,31 @@
 **/
 namespace egr {
+static void CheckTensor(const paddle::experimental::Tensor& pre,
+                        const paddle::experimental::Tensor& post) {
+  if (!pre.initialized() && post.initialized()) {
+    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
+        "The tensor in before and after hook are not consistent"));
+  }
+  if (pre.initialized() && post.initialized()) {
+    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
+            << paddle::framework::DataType2String(post.dtype());
+    PADDLE_ENFORCE_EQ(
+        pre.dtype(), post.dtype(),
+        paddle::platform::errors::PermissionDenied(
+            "The dtype of tensor before(%s) and after(%s) hook are not "
+            "consistent",
+            paddle::framework::DataType2String(pre.dtype()),
+            paddle::framework::DataType2String(post.dtype())));
+    PADDLE_ENFORCE_EQ(
+        pre.place(), post.place(),
+        paddle::platform::errors::PermissionDenied(
+            "The place of tensor before(%s) and after(%s) "
+            "hook are not consistent",
+            pre.place().DebugString(), post.place().DebugString()));
+  }
+}
+
 GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
   VLOG(6) << "Construct GradNodeBase";
   bwd_in_meta_.resize(bwd_in_slot_num);
@@ -271,7 +296,7 @@ void GradNodeBase::SetGradOutMeta(
       // Only Copy Meta
       phi::DenseTensor* dense_tensor =
           static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
-      PADDLE_ENFORCE_NE(dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
+      PADDLE_ENFORCE_NE(dense_tensor->dtype(), phi::DataType::UNDEFINED,
                         paddle::platform::errors::Fatal(
                             "Attempting to copy DenseTensorMeta "
                             "with phi::DataType::UNDEFINED,"
...
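CheckTensor now lives as a file-local static helper here instead of an inline function in the public header; its only job is to verify that a hook left the tensor's dtype and place untouched. A hedged sketch of how such a guard would typically wrap a hook invocation; ApplyHookChecked is a hypothetical name, not a function in this file:

// Illustrative only: run a user-registered tensor hook, then use CheckTensor
// (defined above in this file) to verify the hook preserved dtype and place
// before the result flows back into the backward pass.
static paddle::experimental::Tensor ApplyHookChecked(
    const paddle::experimental::Tensor& grad,
    const std::function<paddle::experimental::Tensor(
        const paddle::experimental::Tensor&)>& hook) {
  paddle::experimental::Tensor hooked = hook(grad);
  CheckTensor(grad, hooked);  // raises a PermissionDenied error on mismatch
  return hooked;
}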
@@ -30,32 +30,23 @@ namespace egr {
 * The GradNodeBase will be held in autograd_meta, and it is also a member of
 * Edge, which indicates the edge of backward graph.
 *
- * TODO:(yangzhanlue) GradNodeBase will also in charge of get the correct input
+ * TODO(yangzhanlue): GradNodeBase will also in charge of get the correct input
 * from GradOpDescMaker to GradNodeBase.
 *
- * NOTE:GradNodeBase has a method named run, this method should be overrided by
- * the
- * specific derived class, it will prepare backward inputs and double backward's
- * depends. Then, it will call C++ API of backward kernel functions to finish
- * backward computation.
+ * NOTE: GradNodeBase has a method named run, this method should be overrided by
+ * the specific derived class, it will prepare backward inputs and double
+ * backward's depends. Then, it will call C++ API of backward kernel functions
+ * to finish backward computation.
 *
- * NOTE:GradNodeBase holds its own inputs and Outputs
+ * NOTE: GradNodeBase holds its own inputs and Outputs
 *
 * Edge is defined to descripe depend of backward, an Edge is what linked
- * between two
- * node, it should contain a Node and rank of this Node (this is used to
- * indicate which
- * input of grad this edge belong).
- * */
+ * between two node, it should contain a Node and rank of this Node (this is
+ * used to indicate which input of grad this edge belong).
+ **/
 class AutogradMeta;
 class GradNodeBase;
+
-/**
- * GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
- * has lots of operators
- * whose backward logic is depends on if it has some specific inputs or outputs.
- * So, we need a meta info
- * to record it's needs.
- * **/
 class Edge {
  public:
  // Default constructor for Edges in order to construct it for AutogradMeta
@@ -64,8 +55,7 @@ class Edge {
  // In real use cases we should create Edge from grad node and input rank which
  // indicate which edge it is.
  // Since we have slot design in operators we will have to locate an edge with
-  // slot
-  // and rank.
+  // slot and rank.
  Edge(const std::shared_ptr<GradNodeBase>& grad_node, size_t in_slot_id,
       size_t in_rank)
      : in_slot_id_(in_slot_id), in_rank_(in_rank), grad_node_(grad_node) {}
@@ -120,6 +110,12 @@ class Edge {
  size_t in_rank_;
  std::shared_ptr<GradNodeBase> grad_node_{nullptr};
};
+
+/**
+ * GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
+ * has lots of operators whose backward logic is depends on if it has some
+ * specific inputs or outputs. So, we need a meta info to record it's needs.
+ **/
 class GradSlotMeta {
  public:
  GradSlotMeta() = default;
@@ -171,16 +167,13 @@ class GradNodeBase {
  /**
   * operator() designed to contian the real backward execution logic, it should
-   * be
-   * overrided by derived class defined for each operator. It accepts a vector
-   * of
-   * Tensor which contains grads input of current operator
+   * be overrided by derived class defined for each operator. It accepts a
+   * vector of Tensor which contains grads input of current operator
   *
   * Note: why we need backward inputs and outputs construct as vector of vector
   * of paddle::experimental::Tensor?
   * Since all of paddle op composite in form of {"Slot name ", vector<Var>},
-   * so, vector of vector
-   * is better choice to fit this format.
+   * so, vector of vector is better choice to fit this format.
   * **/
  virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               kSlotSmallVectorSize>
@@ -294,36 +287,12 @@ class GradNodeBase {
                   /* slot id */ size_t, /* rank */ size_t,
                   /* hook */ std::shared_ptr<TensorHook>>>
      gradient_hooks_;
+  int64_t next_hook_id_{0};
  // We handle complex to real conversion only if any complex GradIn is involved
  bool need_complex_to_real_ = false;
-  int64_t next_hook_id_{0};
  bool is_tensor_wrappers_cleared_ = false;
};
-inline void CheckTensor(const paddle::experimental::Tensor& pre,
-                        const paddle::experimental::Tensor& post) {
-  if (!pre.initialized() && post.initialized()) {
-    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
-        "The tensor in before and after hook are not consistent"));
-  }
-  if (pre.initialized() && post.initialized()) {
-    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
-            << paddle::framework::DataType2String(post.dtype());
-    PADDLE_ENFORCE_EQ(
-        pre.dtype(), post.dtype(),
-        paddle::platform::errors::PermissionDenied(
-            "The dtype of tensor before(%s) and after(%s) hook are not "
-            "consistent",
-            paddle::framework::DataType2String(pre.dtype()),
-            paddle::framework::DataType2String(post.dtype())));
-    PADDLE_ENFORCE_EQ(
-        pre.place(), post.place(),
-        paddle::platform::errors::PermissionDenied(
-            "The place of tensor before(%s) and after(%s) "
-            "hook are not consistent",
-            pre.place().DebugString(), post.place().DebugString()));
-  }
-}
 }  // namespace egr
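The comments reorganized above describe why backward gradients are grouped as a vector of vectors: the outer index is the op slot and the inner index is the rank within that slot, which is exactly the (slot, rank) pair an Edge stores. A simplified stand-alone illustration of that layout, using plain std::vector and made-up names instead of paddle::small_vector and the real Edge class:

#include <cassert>
#include <cstddef>
#include <vector>

struct Tensor {};  // stand-in for paddle::experimental::Tensor

// grads[slot][rank]: the outer index is the op slot ("X", "Y", ...), the
// inner index is the position of the tensor inside that slot.
using SlotGrads = std::vector<std::vector<Tensor>>;

// An edge remembers which (slot, rank) of the next node it feeds.
struct EdgeInfo {
  size_t in_slot_id;
  size_t in_rank;
};

// Route one produced gradient into the next node's buffer at the position the
// edge points to; the buffer is assumed pre-sized to the op's slot count.
void RouteGrad(const Tensor& grad, const EdgeInfo& edge, SlotGrads* next_buffer) {
  assert(edge.in_slot_id < next_buffer->size());
  auto& slot = (*next_buffer)[edge.in_slot_id];
  if (slot.size() <= edge.in_rank) slot.resize(edge.in_rank + 1);
  slot[edge.in_rank] = grad;
}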
@@ -88,6 +88,7 @@ class TensorWrapper {
     } else {
       intermidiate_tensor_.set_impl(tensor.impl());
     }
+
     // TODO(jiabin): This may has server performance issue
     intermidiate_tensor_.set_name(tensor.name() + "@Saved");
@@ -118,24 +119,25 @@ class TensorWrapper {
       paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
       std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
-      if (new_grad_node) {
-        VLOG(3) << "Recovered TensorWrapper with GradNode "
-                << new_grad_node->name() << " addr: " << new_grad_node.get();
-      } else {
-        VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
-      }
       auto* intermediate_autograd_meta =
           EagerUtils::unsafe_autograd_meta(intermidiate_tensor_);
       auto p_ab_autograd_meta =
           std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
       if (new_grad_node) {
+        VLOG(3) << "Recovered TensorWrapper with GradNode "
+                << new_grad_node->name() << " addr: " << new_grad_node.get();
         p_ab_autograd_meta->SetGradNode(new_grad_node);
+      } else {
+        VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
       }
       recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
       return recovered_tensor;
     }
   }
+  void clear() { intermidiate_tensor_.reset(); }
+
+ private:
   void check_inplace_version() {
     if (no_need_buffer_) {
       VLOG(6) << "There's no need to check inplace_version because "
@@ -170,8 +172,6 @@ class TensorWrapper {
     }
   }
-  void clear() { intermidiate_tensor_.reset(); }
-
  private:
   bool full_reserved_ = false;
   bool no_need_buffer_ = false;
...
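recover() above re-attaches a grad node only if the weakly held node is still alive, and clear() releases the saved tensor once backward no longer needs it. A small self-contained sketch of that capture-and-recover pattern with simplified types; it is not TensorWrapper's real interface:

#include <memory>

struct GradNode {};    // stand-in for GradNodeBase
struct TensorLike {};  // stand-in for paddle::experimental::Tensor

class SavedTensor {
 public:
  SavedTensor(const TensorLike& t, const std::shared_ptr<GradNode>& node)
      : saved_(t), weak_node_(node) {}

  // Recover the saved tensor; the grad node is handed back only if it is
  // still alive, mirroring weak_grad_node_.lock() in recover().
  TensorLike Recover(std::shared_ptr<GradNode>* node_out) const {
    *node_out = weak_node_.lock();  // may be nullptr if the node was freed
    return saved_;
  }

  // Drop the saved payload once backward no longer needs it.
  void Clear() { saved_ = TensorLike{}; }

 private:
  TensorLike saved_;
  std::weak_ptr<GradNode> weak_node_;
};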