diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc index 3a2ec403c0a59aaa23decc72fb9581b5a7f78343..9c4089af092e418d6845864671124917c6498cf1 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.cc +++ b/paddle/fluid/eager/accumulation/accumulation_node.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/platform/errors.h" #include "glog/logging.h" - +DECLARE_bool(retain_grad_for_all_tensor); namespace egr { static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, @@ -39,8 +39,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor, } std::vector> GradNodeAccumulation:: -operator()( - const std::vector>& grads) { +operator()(const std::vector>& grads, + bool create_graph) { VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation"; PADDLE_ENFORCE(grads.size() == 1, paddle::platform::errors::Fatal( @@ -62,7 +62,7 @@ operator()( grad_out = grads[0][0]; } - if (!weak_grad_.expired()) { + if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) { auto grad = weak_grad_.lock(); CopyOrAddTensor(grad.get(), grad_out); } diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h index 07fa40165167ce2352018c0e1b1cb08222d5a181..a91a0b6e34c0d9440e3645d1a6982748c4315962 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.h +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -35,8 +35,15 @@ class GradNodeAccumulation : public GradNodeBase { // Functor: perform backward computations virtual std::vector> operator()( - const std::vector>& grads) - override; + const std::vector>& grads, + bool create_graph = false) override; + + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } std::string name() { return "GradNodeAccumulation"; } diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc index 5a2595b9103e4d49845fa8938ee3577b6b3f3f06..0bc998a03a80b7b8a1e486ad68f1575c130d2c1b 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc @@ -145,8 +145,8 @@ void GradNodeScale::SetTensorWrappers_X( void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; } std::vector> GradNodeScale:: -operator()( - const std::vector>& grads) { +operator()(const std::vector>& grads, + bool create_graph) { // 1. 
Check Output Size PADDLE_ENFORCE( ((grads.size() == 1) && (grads[0].size() == 1)), diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h index 247fde6ed1f869542969b068cdae9f59cedd732a..e263f73a6b8a4a1f9ce23d9b5ca383fd6828016b 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h @@ -39,8 +39,15 @@ class GradNodeScale : public GradNodeBase { // Functor: perform backward computations virtual std::vector> operator()( - const std::vector>& grads) - override; + const std::vector>& grads, + bool create_graph = false) override; + + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } void SetTensorWrappers_X( const std::vector& tensors); diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index bf838b27615028167e35a8e85e7636dd4c834016..d9f201dc9f1e8b9a0296288917b82f3e2903330e 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -2074,7 +2074,8 @@ static std::string GenerateGradNodeCCContents( const char* GRAD_FUNCTION_TEMPLATE = "std::vector> " "GradNode%s::operator()(const " - "std::vector>& grads) {\n%s\n}"; + "std::vector>& grads, " + "bool create_graph) {\n%s\n}"; std::string grad_function_str = paddle::string::Sprintf( GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body); @@ -2109,18 +2110,28 @@ static std::string GenerateGradNodeHeaderContents( "\n" " virtual std::vector> " "operator()(const " - "std::vector>& grads) " + "std::vector>& grads, const " + "bool create_graph = false) " "override;\n" "\n" + " void ClearTensorWrappers() override { \n" + "%s\n" + " is_tensor_wrappers_cleared = true;\n" + " }\n" " std::string name() override { return \" GradNode%s \"; } \n " "\n" " // SetX, SetY, ...\n" "%s\n" " // SetAttrMap\n" "%s\n" + " bool IsTensorWrappersCleared() override { \n" + " return is_tensor_wrappers_cleared;\n" + " }\n" " private:\n" " // TensorWrappers\n" "%s\n" + " bool is_tensor_wrappers_cleared = false;\n" + "\n" " // Attribute Map\n" "%s\n" "};"; @@ -2154,6 +2165,7 @@ static std::string GenerateGradNodeHeaderContents( std::string set_tensor_wrappers_str = ""; std::string tensor_wrapper_members_str = ""; + std::string clear_tensor_wrappers_str = ""; for (const auto& iter : op_base_infos) { const std::map& grad_ins_fwd_slotname_map = iter.GetGradInsFwdSlotnameMap(); @@ -2185,6 +2197,13 @@ static std::string GenerateGradNodeHeaderContents( SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, struct_tensor_wrapper_name); + const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = + "for (auto tw: %s) {\n" + " tw.clear();\n" + " }\n"; + clear_tensor_wrappers_str += paddle::string::Sprintf( + CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name); + } else { const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s"; @@ -2197,10 +2216,14 @@ static std::string GenerateGradNodeHeaderContents( TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = - "%s = egr::TensorWrapper(%s, %s /*full_reserved*/);"; + "%s = egr::TensorWrapper(%s, %s /*full_reserved*/);\n"; tensor_wrapper_body_str = paddle::string::Sprintf( 
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, tensor_wrapper_name, full_reserved_str); + + const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n"; + clear_tensor_wrappers_str += paddle::string::Sprintf( + CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name); } std::string full_reserved_signature_str = "bool full_reserved"; const char* SET_TENSOR_WRAPPER_TEMPLATE = @@ -2215,8 +2238,8 @@ static std::string GenerateGradNodeHeaderContents( std::string grad_node_str = paddle::string::Sprintf( GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type, - op_type, op_type, set_tensor_wrappers_str, set_attr_map_str, - tensor_wrapper_members_str, attr_members_str); + op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str, + set_attr_map_str, tensor_wrapper_members_str, attr_members_str); return grad_node_str; } diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index d2d699e154f91d21e94e9d2dac3f703069d041e5..4c1e5b00cbaf6fd0688663c9dac756832e44dc4a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -478,6 +478,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, # SetTensorWrapper Methods & TensorWrapper Members set_tensor_wrapper_methods_str = "" tensor_wrapper_members_str = "" + clear_tensor_wrapper_str = "" for tname, (ttype, is_fwd_input, _) in backward_fwd_input_map.items(): if tname in no_need_buffer_set: no_need_buffer = "true" @@ -499,6 +500,13 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, """ tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_name) + + CLEAR_TENSOR_WRAPPERS_TEMPLATE = """ + {}.clear(); +""" + clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format( + tensor_wrapper_name) + else: assert IsVectorTensorType(ttype) SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """ @@ -516,6 +524,15 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, """ tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format( tensor_wrapper_name) + + CLEAR_TENSOR_WRAPPERS_TEMPLATE = """ + for (auto tw: {}) { + tw.clear(); + }; +""" + clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPERS_TEMPLATE.format( + tensor_wrapper_name) + # End: SetTensorWrapper Methods & TensorWrapper Members # SetAttributes & Attribute Members @@ -524,7 +541,7 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, for aname, atype, default_val, _ in backward_attrs_list: saved_attr_name = GetSavedName(aname) SET_ATTR_METHOD_TEMPLATE = """ - void SetAttribute{}({} {}) {{ + void SetAttribute{}({} {}) {{ {} = {}; }} """ @@ -555,25 +572,37 @@ class {} : public egr::GradNodeBase {{ ~{}() override = default; virtual std::vector> operator()( - const std::vector>& grads) override; + const std::vector>& grads, bool create_graph = false) override; std::string name() override {{ return \" {} \"; }} + + void ClearTensorWrappers() override {{ + {} + is_tensor_wrappers_cleared = true; + }} + // SetTensorWrapperX, SetTensorWrapperY, ... 
{} // SetAttributes {} + + bool IsTensorWrappersCleared() override {{ + return is_tensor_wrappers_cleared; + }} private: // TensorWrappers {} + bool is_tensor_wrappers_cleared = false; + // Attributes {} }}; """ node_declaration_str = NODE_DECLARATION_TEMPLATE.format( grad_node_name, grad_node_name, grad_node_name, grad_node_name, - grad_node_name, set_tensor_wrapper_methods_str, - set_attribute_methods_str, tensor_wrapper_members_str, - attribute_members_str) + grad_node_name, clear_tensor_wrapper_str, + set_tensor_wrapper_methods_str, set_attribute_methods_str, + tensor_wrapper_members_str, attribute_members_str) return node_declaration_str @@ -637,7 +666,7 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, grad_api_namespace = f"paddle::experimental" FUNCTION_TEMPLATE = """ -std::vector> {}::operator()(const std::vector>& grads) {{ +std::vector> {}::operator()(const std::vector>& grads, bool create_graph) {{ // Call grad_api function auto grad_api_returns = {}::{}({}); {} diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 1987d024d8f3e34121f54962c45f0f8c1e91b723..f2d5f338bd4af6ce1f9858d35485a0b23dc5f61c 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -39,12 +39,21 @@ std::unordered_map getInDegreeMap( // Copy nodes std::queue queue = init_queue; std::unordered_set visited; + size_t potential_startup_ops_cnt = queue.size(); + size_t cnt = 0; // Visit each node exactly once in any order while (!queue.empty()) { GradNodeBase* node = queue.front(); queue.pop(); + if (cnt < potential_startup_ops_cnt) { + if (!node_in_degree_map.count(node)) { + node_in_degree_map[node] = 0; + } + cnt += 1; + } + if (visited.count(node)) { continue; } @@ -76,23 +85,248 @@ std::unordered_map getInDegreeMap( return node_in_degree_map; } -void RunBackward(const std::vector& tensors, - const std::vector& grad_tensors, - bool retain_graph) { - paddle::platform::RecordEvent backward_record_event( - "backward", paddle::platform::TracerEventType::Operator, 1); +// Remove some nodes those doesn't need to be +// stored in potential_stop_nodes、potential_startup_nodes +void UpdateGraphInfo( + std::unordered_map* + target_nodes_inputmeta_map, + std::unordered_map>* + depending_nodes, + std::unordered_set* potential_stop_nodes, + std::unordered_set* potential_startup_nodes) { + // Updated potential_sotp_nodes by depending_nodes, + // make sure the path from root to target_node is ok + std::unordered_set _startup_ops; + VLOG(6) << "Running in UpdateGraphInfo"; + std::queue queue; + for (auto& target_nodes_inputmeta_pair : *target_nodes_inputmeta_map) { + queue.emplace(target_nodes_inputmeta_pair.first); + } + + while (!queue.empty()) { + auto* target_node = queue.front(); + queue.pop(); + if (!(*depending_nodes)[target_node].empty()) { + auto precedding_nodes = (*depending_nodes)[target_node]; + for (auto pre_nodes : precedding_nodes) { + queue.emplace(pre_nodes); + if (potential_stop_nodes->find(pre_nodes) != + potential_stop_nodes->end()) { + potential_stop_nodes->erase(pre_nodes); + } + } + } else { // startup_ops have no precedding nodes + VLOG(6) << "Emplace _startup_ops"; + _startup_ops.emplace(target_node); + } + } + // Purify potential_startup_nodes again, remove some + // potential startup_nodes that unreach to input target nodes + if (!_startup_ops.empty()) { + std::unordered_set potential_startup_nodes_to_be_erased; + for (auto node : *potential_startup_nodes) { + if (_startup_ops.count(node) == 0) { + VLOG(6) << 
"Set up potential_startup_nodes_to_be_erased"; + potential_startup_nodes_to_be_erased.emplace(node); + } + } + if (!potential_startup_nodes_to_be_erased.empty()) { + for (auto node : potential_startup_nodes_to_be_erased) { + VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased"; + potential_startup_nodes->erase(node); + } + } + } +} + +// Get Graph Info Betweent input target gradnode and outputs, +// record depending_nodes、 potential_stop_nodes、potential_startup_nodes +void GetGraphInfoBetweenTargets( + const std::queue& init_queue, + std::unordered_map* + input_target_nodes_inputmeta_map, + std::unordered_map>* + depending_nodes, + std::unordered_set* potential_stop_nodes, + std::unordered_set* potential_startup_nodes) { + if (input_target_nodes_inputmeta_map->empty()) return; + + VLOG(6) << "Runing In GetGraphInfoBetweenTargets"; + + // Calculate in_degree for each node + std::unordered_map node_in_degree_map; + + // Copy nodes + std::queue queue = init_queue; + std::unordered_set visited; + + // Visit each node exactly once in any order + while (!queue.empty()) { + GradNodeBase* node = queue.front(); + queue.pop(); + + if (visited.count(node)) { + continue; + } + visited.insert(node); + + // Check node is target_nodes or not, if node is not target_node, + // all the next_node will be marked in potential_stop_nodes + bool is_potential_stop_nodes = + input_target_nodes_inputmeta_map->count(node); + + // Find and append next nodes + const std::vector>& edges = node->GetEdges(); + for (const auto& edge_list : edges) { + for (const Edge& edge : edge_list) { + GradNodeBase* next_node = edge.GetMutableGradNode().get(); + + // Next node could be nullptr if it is leaf tensor with no + // AccumulationNode attached + // Or it could also originated from dispensable inputs + if (!next_node) continue; + + // if node not in input_target_nodes, + // all the next_nodes of current node will be inserted to + // potential_stop_node + if (is_potential_stop_nodes) { + potential_stop_nodes->emplace(next_node); + } + + // Update in_degree + if (!node_in_degree_map.count(next_node)) + node_in_degree_map[next_node] = 0; + node_in_degree_map[next_node]++; + // Record depending relationship + (*depending_nodes)[next_node].emplace(node); + queue.push(next_node); + } + } + } + // Update Graph Info, remove some stop_node in potential_stop_nodes + UpdateGraphInfo(input_target_nodes_inputmeta_map, depending_nodes, + potential_stop_nodes, potential_startup_nodes); +} + +void GetTargetNodesInfo(const std::vector& inputs, + std::unordered_map* + target_nodes_inputmeta_map) { + VLOG(6) << "Running in GetTargetNodesInfo"; + if (!inputs.empty()) { + VLOG(6) << "Inputs are not empty"; + size_t num_inputs = inputs.size(); + for (size_t i = 0; i < num_inputs; i++) { + AutogradMeta* auto_grad_meta = + EagerUtils::unsafe_autograd_meta(inputs[i]); + auto target_node = auto_grad_meta->GetMutableGradNode().get(); + + PADDLE_ENFORCE_NOT_NULL(target_node, + paddle::platform::errors::Fatal( + "There is no grad op for input:%d or it's" + "stop_gradient=True", + i)); + (*target_nodes_inputmeta_map)[target_node] = auto_grad_meta; + } + } +} + +std::vector GetResults( + const std::vector& inputs, + std::unordered_map* + results_map, + bool allow_unused, bool create_graph) { + VLOG(6) << "Running in GetResults"; + if (inputs.empty()) return {}; + + std::vector results; + results.reserve(inputs.size()); + + for (size_t i = 0; i < inputs.size(); ++i) { + auto& input = inputs[i]; + AutogradMeta* auto_grad_meta = 
EagerUtils::unsafe_autograd_meta(input); + auto target_node = auto_grad_meta->GetMutableGradNode().get(); + + auto iter = results_map->find(target_node); + if (iter != results_map->end()) { + // set StopGradient = !create_graph + AutogradMeta* tensor_auto_grad_meta = + EagerUtils::autograd_meta(&(iter->second)); + tensor_auto_grad_meta->SetStopGradient(!create_graph); + results.emplace_back(iter->second); + } else { + PADDLE_ENFORCE_EQ(allow_unused, true, + paddle::platform::errors::InvalidArgument( + "The %d-th input does not appear in the backward " + "graph. Please check the input variable or set " + "allow_unused=True to get None result.", + i)); + results.emplace_back(); + } + } + return results; +} + +// Enforce GradNode has TensorWrappers as Input +void EnforceGradNodeHasInput(GradNodeBase* node) { + VLOG(6) << "Running in EnforceGradNodeHasInput"; + PADDLE_ENFORCE_NE( + node->IsTensorWrappersCleared(), true, + paddle::platform::errors::Fatal( + "The TensorWrappers of %s do not exist. This may be because:\n" + "You calculate backward twice for the same subgraph without " + "setting retain_graph=True. Please set retain_graph=True in the " + "first backward/grad call.\n", + node->name())); +} + +// Purify potential_startup_nodes, remove nodes those are the same as +// input_target_nodes +void PurifyPotentialStartUpNodes( + std::unordered_set* potential_startup_nodes, + std::unordered_map* + input_target_nodes_inputmeta_map) { + VLOG(6) << "Running in PurifyPotentialStartUpNodes"; + if (input_target_nodes_inputmeta_map->empty()) return; + std::unordered_set potential_startup_nodes_to_be_erased; + for (auto startup_op : *potential_startup_nodes) { + auto iter = input_target_nodes_inputmeta_map->find(startup_op); + if (iter != input_target_nodes_inputmeta_map->end()) { + potential_startup_nodes_to_be_erased.emplace(iter->first); + } + } + if (!potential_startup_nodes_to_be_erased.empty()) { + for (auto nodes : potential_startup_nodes_to_be_erased) { + potential_startup_nodes->erase(nodes); + } + } +} + +std::vector RunBackward( + const std::vector& tensors, // output + const std::vector& grad_tensors, + bool retain_graph, bool create_graph = false, + const std::vector& inputs = {}, + bool allow_unused = false, + const std::vector& no_grad_vars = {}) { VLOG(6) << "Start Backward"; // *Gradient Hook should happen at node-level // *Inplace version check should perform at node-level // *Cross-batch accumulation happens at forward pass + std::unordered_map + no_grad_var_nodes_inputmeta_map; + // Get no_grad_vars's GradNodes and InputMeta Info + GetTargetNodesInfo(no_grad_vars, &no_grad_var_nodes_inputmeta_map); + /* --- Initialization --- */ // 1. Init queue with starting nodes // 2. 
Prepare initial input buffers std::queue queue; std::unordered_map> node_input_buffers_dict; + std::unordered_set potential_startup_nodes; for (size_t i = 0; i < tensors.size(); i++) { const paddle::experimental::Tensor& tensor = tensors[i]; @@ -132,8 +366,17 @@ void RunBackward(const std::vector& tensors, "size = 0 or same size as tensors")); // Feed given tensor if it's provided VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor"; - node_input_buffers_dict[grad_node]->add( - input_info.first, input_info.second, grad_tensors[i]); + + if (grad_tensors[i].is_initialized()) { + // Deep copy + paddle::experimental::Tensor tmp_tensor; + tmp_tensor.copy_(grad_tensors[i], true); + node_input_buffers_dict[grad_node]->add(input_info.first, + input_info.second, tmp_tensor); + } else { + node_input_buffers_dict[grad_node]->add( + input_info.first, input_info.second, grad_tensors[i]); + } } else { VLOG(6) << "Fill grad input tensor " << i << " with 1.0"; @@ -146,8 +389,9 @@ void RunBackward(const std::vector& tensors, input_info.first, input_info.second, tensor, true /*fill_one=true*/); } - // Prepare queue + // Prepare queue, potential startup_nodes queue.push(grad_node); + potential_startup_nodes.emplace(grad_node); } VLOG(6) << "Update In degree Map for backward"; @@ -155,25 +399,74 @@ void RunBackward(const std::vector& tensors, std::unordered_map node_in_degree_map = getInDegreeMap(queue); + // Get input's GradNodes and InputMeta Info + std::unordered_map + input_target_nodes_inputmeta_map; + GetTargetNodesInfo(inputs, &input_target_nodes_inputmeta_map); + + // Purify potential_startup_ops, remove those nodes that are the same as + // input_target_nodes + PurifyPotentialStartUpNodes(&potential_startup_nodes, + &input_target_nodes_inputmeta_map); + + // Get Graph Info Betweent input target gradnode and outputs + // Record the depending_nodes and potential_stop_nodes + std::unordered_map /* father node */> + depending_nodes; + std::unordered_set potential_stop_nodes; + // std::unordered_set startup_ops; + + GetGraphInfoBetweenTargets(queue, &input_target_nodes_inputmeta_map, + &depending_nodes, &potential_stop_nodes, + &potential_startup_nodes); + + // ready_queue store all startup nodes + std::queue ready_queue; + // startup op's indegree should be 0 + for (auto node : potential_startup_nodes) { + if (node_in_degree_map[node] == 0) { + ready_queue.emplace(node); + } + } + + VLOG(1) << " startup_ops' size is :" << ready_queue.size(); + + std::unordered_map results_map; + + // read_queue is empty only when 1.input equals to output. 2.input can not + // reach to output. + if (ready_queue.size() == 0) { + for (auto input_target_node : input_target_nodes_inputmeta_map) { + // out rank_info of forward op + auto rank_info = input_target_node.second->OutRankInfo(); + if (node_input_buffers_dict[input_target_node.first]) { + auto& target_result = + node_input_buffers_dict[input_target_node.first] + ->Buffers()[rank_info.first][rank_info.second]; + // save the target result + results_map[input_target_node.first] = target_result; + } + } + } + /* --- Topological Visit --- */ // 1. Pop queue // 2. Run node + // |- Check and capture target result // |- node(grads) // |- Prepare for next node // 3. 
Update queue
   VLOG(6) << "Run Backward";
-  while (!queue.empty()) {
-    GradNodeBase* node = queue.front();
+  while (!ready_queue.empty()) {
+    GradNodeBase* node = ready_queue.front();
+    VLOG(6) << "Running GradNode:" << node->name();
+    ready_queue.pop();
     paddle::platform::RecordEvent node_record_event(
         std::string(typeid(*node).name()) + " grad_node",
         paddle::platform::TracerEventType::Operator, 1);
-    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
-      queue.pop();
-      continue;
-    }
-    queue.pop();
     // Run node: This is where Hook happens
     PADDLE_ENFORCE(
         node_input_buffers_dict.count(node),
@@ -184,10 +477,45 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
     std::unique_ptr<GradTensorHolder> node_input_buffer =
         std::move(node_input_buffers_dict[node]);
+    // get the target grad_var from node_input_buffer by inputmeta
+    if (input_target_nodes_inputmeta_map.find(node) !=
+        input_target_nodes_inputmeta_map.end()) {
+      VLOG(6) << "Get target result by inputmeta";
+      // out rank_info of forward op
+      auto rank_info = input_target_nodes_inputmeta_map[node]->OutRankInfo();
+      // rank_info is a pair; first is the slot_id, second is the rank.
+      auto& target_result =
+          node_input_buffer->Buffers()[rank_info.first][rank_info.second];
+      // save the target result
+      results_map[node] = target_result;
+    }
+
+    // no_grad_vars
+    if (no_grad_var_nodes_inputmeta_map.find(node) !=
+        no_grad_var_nodes_inputmeta_map.end()) {
+      VLOG(6) << "Set the input buffer[slot][rank] to zeros";
+      auto rank_info = no_grad_var_nodes_inputmeta_map[node]->OutRankInfo();
+      node_input_buffer->SetBufferSlotRankZeros(rank_info.first,
+                                                rank_info.second);
+    }
+
+    VLOG(6) << "Running GradNode:" << node->name();
+
+    // check that the node still holds its input TensorWrappers
+    EnforceGradNodeHasInput(node);
+
     VLOG(6) << "Run Backward Kernel with GradTensorHolder";
     // Run Pre Backward Node and get outputs
     std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
-        (*node)(node_input_buffer->Buffers());
+        (*node)(node_input_buffer->Buffers(), create_graph);
+
+    // retain_graph or not
+    if (!retain_graph) {
+      VLOG(6)
+          << "retain_graph is false, need to clear the TensorWrappers of nodes.";
+      node->ClearTensorWrappers();
+    }
+
     // TODO(jiabin): Should we erase it or find a more efficient way.
     node_input_buffers_dict.erase(node);
@@ -252,18 +580,44 @@ void RunBackward(const std::vector<paddle::experimental::Tensor>& tensors,
         // Update queue
         node_in_degree_map[next_node]--;
+
         PADDLE_ENFORCE(
             node_in_degree_map[next_node] >= 0,
             paddle::platform::errors::Fatal(
                 "Detected in-degree value smaller than zero. 
For Node: %s" "Node's in-degree cannot be negative", next_node->name())); - if (node_in_degree_map[next_node] == 0) { - queue.emplace(std::move(next_node)); + + bool is_potential_stop_node = potential_stop_nodes.count(next_node); + + if (node_in_degree_map[next_node] == 0 && !is_potential_stop_node) { + ready_queue.emplace(std::move(next_node)); } } } } + + return GetResults(inputs, &results_map, allow_unused, create_graph); } +void Backward( + const std::vector& tensors, // output + const std::vector& grad_tensors, + bool retain_graph) { + VLOG(6) << "Run in Backward"; + paddle::platform::RecordEvent backward_record_event( + "backward", paddle::platform::TracerEventType::Operator, 1); + RunBackward(tensors, grad_tensors, retain_graph); +} + +std::vector Grad( + const std::vector& tensors, // output + const std::vector& inputs, + const std::vector& grad_tensors, + bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused, + const std::vector& no_grad_vars) { + VLOG(6) << "Run in Grad"; + return RunBackward(tensors, grad_tensors, retain_graph, create_graph, inputs, + allow_unused, no_grad_vars); +} } // namespace egr diff --git a/paddle/fluid/eager/backward.h b/paddle/fluid/eager/backward.h index 2856d9fb87f34b1066bb59eb38bcaee786d2a260..bebe664838e6c1f98219ceee6e6733b49c319b3c 100644 --- a/paddle/fluid/eager/backward.h +++ b/paddle/fluid/eager/backward.h @@ -19,12 +19,20 @@ namespace egr { -// run_backward(): +// Backward(): // tensors corresponds to those lived in the backward graph // each grad_tensors[i] keeps the value for its corresponding tensors[i] -void RunBackward(const std::vector &tensors, - const std::vector &grad_tensors, - bool retain_graph = false); +void Backward(const std::vector& tensors, + const std::vector& grad_tensors, + bool retain_graph = false); + +std::vector Grad( + const std::vector& tensors, + const std::vector& inputs, + const std::vector& grad_tensors = {}, + bool retain_graph = false, bool create_graph = false, + bool only_inputs = false, bool allow_unused = false, + const std::vector& no_grad_vars = {}); // Reserved for gradient() diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 48ac8c8358afd68cee9d22b8ea0a4e8fd7c3c92e..72af1cc4b068679e72ae6bdc5e09fab8f56bac04 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -20,8 +20,8 @@ namespace egr { std::vector> RunCustomOpNode:: -operator()( - const std::vector>& grads) { +operator()(const std::vector>& grads, + bool create_graph) { paddle::CustomOpKernelContext ctx; auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs( egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.h b/paddle/fluid/eager/custom_operator/custom_operator_node.h index e5ddef9c062149282d790a5fd6bf31b25a20cf5a..6ece2658575c795856438904c2716d61f0985879 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.h +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.h @@ -37,8 +37,8 @@ class RunCustomOpNode : public GradNodeBase { // Functor: perform backward computations virtual std::vector> operator()( - const std::vector>& grads) - override; + const std::vector>& grads, + bool create_graph) override; std::string name() { return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_); @@ -62,6 +62,12 @@ class RunCustomOpNode : public 
GradNodeBase { return res; } + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } + void SetAttrs(const std::vector& attr) { attrs_ = attr; } public: diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 16513f05e0777a8e57f54c925d68867dda656612..168e1bcca77ca85eb6fa90a23350d1f62f63dc8e 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -95,8 +95,12 @@ class GradNodeBase { * is better choice to fit this format. * **/ virtual std::vector> operator()( - const std::vector>& grads) = 0; + const std::vector>& grads, + bool create_graph = false) = 0; + virtual void ClearTensorWrappers() = 0; + + virtual bool IsTensorWrappersCleared() = 0; /** * AddEdges is designed to set input tensors' backward Node as current * node's Edges. diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 69fc7df2f1420382735cf59fbe85f7e2207d0f77..163d25e85ce8c085087331c6e3273075aed5e5f4 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -21,6 +21,11 @@ namespace egr { +void GradTensorHolder::SetBufferSlotRankZeros(size_t slot_id, size_t rank) { + buffer_[slot_id][rank] = + paddle::experimental::zeros_like(buffer_[slot_id][rank]); +} + void GradTensorHolder::add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t, bool fill_one) { diff --git a/paddle/fluid/eager/grad_tensor_holder.h b/paddle/fluid/eager/grad_tensor_holder.h index d66a81fe8285980bad4159d5414985dc9c744549..9059b403607461cc980a58d345fe1542aa4b1903 100644 --- a/paddle/fluid/eager/grad_tensor_holder.h +++ b/paddle/fluid/eager/grad_tensor_holder.h @@ -56,6 +56,8 @@ class GradTensorHolder { return buffer_; } + void SetBufferSlotRankZeros(size_t slot_id, size_t rank); + private: std::vector> buffer_; }; diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h index 31aaa93c41643f565836c536d7001c01d2a0826d..0e11444b81526de1904b72fc983814314d834a45 100644 --- a/paddle/fluid/eager/tensor_wrapper.h +++ b/paddle/fluid/eager/tensor_wrapper.h @@ -98,6 +98,8 @@ class TensorWrapper { } } + void clear() { intermidiate_tensor_.reset(); } + private: bool full_reserved_ = false; std::pair out_rank_info_; diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h index 535c93ac53b1751d9634476e47f32dc0cbe22708..0b167203735d65683b0f978fa34fe7f457aae4f2 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h +++ b/paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h @@ -32,8 +32,8 @@ class GradTestNode : public egr::GradNodeBase { GradTestNode() : GradNodeBase() { val_ = 1.0; } std::string name() override { return "GradTestNode"; } std::vector> operator()( - const std::vector>& grads) - override { + const std::vector>& grads, + bool create_graph = false) override { val_ = std::dynamic_pointer_cast(grads[0][0].impl()) ->data()[0]; phi::DenseTensorMeta meta = @@ -49,6 +49,11 @@ class GradTestNode : public egr::GradNodeBase { std::vector> res = {{et1}}; return res; } + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } float val_; }; } // namespace eager_test diff --git 
a/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc b/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc index 769bd7f687f4584d44bbfa30b73611a3128289bf..887ea3e3acfd50a15206f3e84ab45e16707f80af 100644 --- a/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc +++ b/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc @@ -58,7 +58,7 @@ void benchmark_eager_scale(const paddle::experimental::Tensor& tensor, } std::vector target_tensors = {input_tensor}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); if (accuracy_check) { // Examine Forward Grad (w.r.t max_num_runs = 10) @@ -80,7 +80,7 @@ void benchmark_eager_matmul(const paddle::experimental::Tensor& X, } std::vector target_tensors = {input_tensor0}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); if (accuracy_check) { // Examine Forward Grad (w.r.t max_num_runs = 2) @@ -106,7 +106,7 @@ void benchmark_eager_intermediate_matmul(const paddle::experimental::Tensor& X, } std::vector target_tensors = {input_tensor0}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); if (accuracy_check) { // Examine Forward Grad (w.r.t max_num_runs = 2) @@ -137,7 +137,7 @@ void benchmark_eager_intermediate_mlp( reduce_sum_dygraph_function(input0, {{"reduce_all", true}}); std::vector target_tensors = {Out}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); if (accuracy_check) { std::unordered_map result = diff --git a/paddle/fluid/eager/tests/task_tests/CMakeLists.txt b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt index c65ad4641cf2206cc0f97d91f1fb24e50b7b63cd..52dba6b9218c7be8a29ae1aff619facd25a6f3b6 100644 --- a/paddle/fluid/eager/tests/task_tests/CMakeLists.txt +++ b/paddle/fluid/eager/tests/task_tests/CMakeLists.txt @@ -5,6 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_ cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) +cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) cc_test(test_egr_task_hook_intermidiate SRCS hook_test_intermidiate.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} dygraph_node) diff --git a/paddle/fluid/eager/tests/task_tests/backward_test.cc b/paddle/fluid/eager/tests/task_tests/backward_test.cc index 0c894ed267fcdd08d44d4df08bfaf0554874aebf..87f8f6eca1f88fe9a54583ee19586dd75c7e231e 100644 --- a/paddle/fluid/eager/tests/task_tests/backward_test.cc +++ b/paddle/fluid/eager/tests/task_tests/backward_test.cc @@ -33,6 +33,7 @@ #include "paddle/phi/core/kernel_registry.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); namespace egr { @@ -79,7 +80,7 @@ TEST(Backward, SingleNodeEmptyGrad) { } std::vector outs = {target_tensor}; // Run Backward - RunBackward(outs, {}); + Backward(outs, {}); // Check Output Value eager_test::CompareGradTensorWithValue(leaf_tensor, 5.0); @@ -138,7 +139,7 @@ TEST(Backward, SingleNodeCustomGrad) { } // Run Backward - RunBackward(target_tensors, grad_tensors); + Backward(target_tensors, grad_tensors); // Check Output Value eager_test::CompareGradTensorWithValue(leaf_tensor, 50.0); @@ -211,7 +212,7 @@ TEST(Backward, LinearNodes) { } 
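// Illustrative, self-contained sketch of the traversal strategy the backward.cc
// changes in this patch rely on: build an in-degree map over the backward graph,
// seed a ready queue with zero-in-degree startup nodes, and enqueue a successor
// only after all of its producers have run. `Node`, `TopologicalBackward`, and
// the node names are hypothetical stand-ins, not Paddle's GradNodeBase/Edge types.
#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

struct Node {
  const char* name;
  std::vector<Node*> next;  // successors in the backward graph
};

void TopologicalBackward(const std::vector<Node*>& startup_nodes) {
  // 1. Compute the in-degree of every reachable node (mirrors getInDegreeMap).
  std::unordered_map<Node*, int> in_degree;
  std::unordered_map<Node*, bool> visited;
  std::queue<Node*> bfs;
  for (Node* n : startup_nodes) {
    bfs.push(n);
    in_degree[n];  // make sure startup nodes have an entry (default 0)
  }
  while (!bfs.empty()) {
    Node* n = bfs.front();
    bfs.pop();
    if (visited[n]) continue;
    visited[n] = true;
    for (Node* nxt : n->next) {
      ++in_degree[nxt];
      bfs.push(nxt);
    }
  }
  // 2. The ready queue only ever holds nodes whose grad inputs are complete.
  std::queue<Node*> ready;
  for (Node* n : startup_nodes)
    if (in_degree[n] == 0) ready.push(n);
  while (!ready.empty()) {
    Node* n = ready.front();
    ready.pop();
    std::cout << "run " << n->name << "\n";  // stand-in for (*node)(grads, ...)
    for (Node* nxt : n->next) {
      if (--in_degree[nxt] == 0) ready.push(nxt);  // all producers finished
    }
  }
}

int main() {
  Node leaf{"accumulate", {}};
  Node scale1{"scale_grad_1", {&leaf}};
  Node scale0{"scale_grad_0", {&scale1}};
  TopologicalBackward({&scale0});  // prints scale_grad_0, scale_grad_1, accumulate
}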
// Use Empty Grad Tensor - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); // Check Output Value eager_test::CompareGradTensorWithValue(leaf_tensor, 50.0); @@ -315,7 +316,7 @@ TEST(Backward, WithAccumulation) { node2_ptr->AddEdges(&res2, 0); } - RunBackward(target_tensors, grad_tensors); + Backward(target_tensors, grad_tensors); eager_test::CompareGradTensorWithValue(leaf_tensor, 2500.0); } diff --git a/paddle/fluid/eager/tests/task_tests/cross_batch_accumulation_test.cc b/paddle/fluid/eager/tests/task_tests/cross_batch_accumulation_test.cc index 36594f1aac8cdb131bb77f1396dca19a0c2e8cc0..8b0759c17ed3712079e8954df60e35afaaf02a9e 100644 --- a/paddle/fluid/eager/tests/task_tests/cross_batch_accumulation_test.cc +++ b/paddle/fluid/eager/tests/task_tests/cross_batch_accumulation_test.cc @@ -71,12 +71,12 @@ TEST(CrossBatchAccumulation, SingleScaleNode) { std::vector res = {meta}; scale_node_ptr->AddEdges(&res, 0); - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(target_tensor, 1.0); eager_test::CompareGradTensorWithValue(leaf_tensor, 5.0); - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(target_tensor, 1.0); eager_test::CompareGradTensorWithValue(leaf_tensor, 10.0); diff --git a/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc b/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc index f7fa642ea8dd17d20816e74c9bfb4cd92b184b4a..882695e98d109e09340223e21322a02d1b48c6ea 100644 --- a/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc +++ b/paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc @@ -86,7 +86,7 @@ TEST(FwdBwdJoint, SingleNode) { std::vector outs = {out}; // 4. Run Backward - RunBackward(outs, {}); + Backward(outs, {}); VLOG(7) << "Target Grad is: " << std::static_pointer_cast( @@ -137,7 +137,7 @@ TEST(FwdBwdJoint, LinearNodes) { std::vector outs = {out1}; // 4. Run Backward - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 10.0); @@ -203,7 +203,7 @@ TEST(FwdBwdJoint, BranchedNodes) { // 4. Run Backward std::vector outs = {out1, out2}; - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 30.0); @@ -260,7 +260,7 @@ TEST(FwdBwdJoint, GradientHook) { // 4. Run Backward std::vector outs = {out1, out2}; - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad // leaf grad @@ -318,13 +318,13 @@ TEST(FwdBwdJoint, CrossBatchAccumulation) { // 4. Run Backward std::vector outs = {out1, out2}; - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 30.0); // Cross Batch Accumulation - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 60.0); @@ -356,7 +356,7 @@ TEST(FwdBwdJoint, SingleNodeCUDA) { std::vector outs = {out}; // 4. Run Backward - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 2.0); @@ -412,7 +412,7 @@ TEST(FwdBwdJoint, BranchedNodesCUDA) { // TODO(jiabin): fix this with add functor // 4. 
Run Backward std::vector outs = {out1, out2}; - RunBackward(outs, {}); + Backward(outs, {}); // Examine Backward Grad eager_test::CompareGradTensorWithValue(tensor, 30.0); diff --git a/paddle/fluid/eager/tests/task_tests/generated_test.cc b/paddle/fluid/eager/tests/task_tests/generated_test.cc index 2a5ad53204a6201149bec0b3dac0fa3baf441f2e..68820443a2d5a68c73c0d5ebb855519fddbbf3d2 100644 --- a/paddle/fluid/eager/tests/task_tests/generated_test.cc +++ b/paddle/fluid/eager/tests/task_tests/generated_test.cc @@ -57,7 +57,7 @@ TEST(Generated, Sigmoid) { std::vector target_tensors = {output_tensor}; VLOG(6) << "Runing Backward"; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); VLOG(6) << "Finish Backward"; eager_test::CompareGradTensorWithValue(tensor, 0.25); @@ -89,7 +89,7 @@ TEST(Generated, Matmul_v2) { eager_test::CompareTensorWithValue(output_tensor, 96); std::vector target_tensors = {output_tensor}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(X, 2.0 * 20); eager_test::CompareGradTensorWithValue(Y, 3.0 * 4); @@ -120,7 +120,7 @@ TEST(Generated, ElementwiseAdd) { eager_test::CompareTensorWithValue(output_tensor, 5); std::vector target_tensors = {output_tensor}; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(X, 1.0); eager_test::CompareGradTensorWithValue(Y, 1.0); diff --git a/paddle/fluid/eager/tests/task_tests/grad_test.cc b/paddle/fluid/eager/tests/task_tests/grad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b03799c48659c579938df6efc0f7cf57bbc0bec --- /dev/null +++ b/paddle/fluid/eager/tests/task_tests/grad_test.cc @@ -0,0 +1,339 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
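// Illustrative sketch of the retain_graph bookkeeping this patch introduces:
// when retain_graph is false the engine clears each node's saved TensorWrappers
// after running it, and a second backward pass over the same subgraph is rejected
// (compare ClearTensorWrappers / IsTensorWrappersCleared / EnforceGradNodeHasInput
// above). `GradNodeSketch` and its members are hypothetical stand-ins, not the
// real GradNodeBase interface.
#include <iostream>
#include <optional>
#include <stdexcept>

class GradNodeSketch {
 public:
  explicit GradNodeSketch(double saved_forward_value)
      : saved_(saved_forward_value) {}

  double Run(double grad_in) const {
    if (cleared_) {
      // mirrors EnforceGradNodeHasInput: the saved tensors are gone
      throw std::runtime_error(
          "TensorWrappers were cleared; set retain_graph=True on the first "
          "backward/grad call to run backward twice.");
    }
    return grad_in * saved_.value();  // uses the saved forward activation
  }

  void ClearTensorWrappers() {
    saved_.reset();
    cleared_ = true;
  }
  bool IsTensorWrappersCleared() const { return cleared_; }

 private:
  std::optional<double> saved_;  // stand-in for egr::TensorWrapper members
  bool cleared_ = false;
};

void BackwardOnce(GradNodeSketch* node, bool retain_graph) {
  std::cout << "grad = " << node->Run(1.0) << "\n";
  if (!retain_graph) node->ClearTensorWrappers();
}

int main() {
  GradNodeSketch node(3.0);
  BackwardOnce(&node, /*retain_graph=*/false);  // ok, prints grad = 3
  try {
    BackwardOnce(&node, false);  // second pass: wrappers already cleared
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
}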
+ +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h" +#include "paddle/fluid/eager/api/utils/tensor_utils.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/backward.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/tests/test_utils.h" + +#include "paddle/fluid/eager/api/all.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" + +PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); +namespace egr { + +TEST(Grad, SingleNodeEmptyGrad) { + // Prepare Device Contexts + eager_test::InitEnv(paddle::platform::CPUPlace()); + + // Prepare Inputs + paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); + + // Create Target Tensor (output) + paddle::experimental::Tensor output_tensor = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + + // Create input tensor + const paddle::experimental::Tensor leaf_tensor = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + + { + // Create Scale Node + auto node0_ptr = std::make_shared(1, 1); + node0_ptr->SetAttributes_scale(5.0 /*scale*/); + + // Set grad in/out meta + node0_ptr->SetDefaultGradInOutMeta(); + + // Output_tensor set GradNode、OutRank、StopGradient propertis + AutogradMeta* auto_grad_meta = EagerUtils::autograd_meta(&output_tensor); + auto_grad_meta->SetGradNode( + std::dynamic_pointer_cast(node0_ptr)); + auto_grad_meta->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta->SetStopGradient(false); + + // Get autograd_meta from input tensor + AutogradMeta* auto_grad_meta1 = + EagerUtils::unsafe_autograd_meta(leaf_tensor); + + // Connect Tensor and AccumulationNode via AutoGradMeta + auto acc_node_ptr = + std::make_shared(auto_grad_meta1); + + // input tensor set GradNode、OutRank、StopGradient propertis + auto_grad_meta1->SetGradNode( + std::dynamic_pointer_cast(acc_node_ptr)); + auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta1->SetStopGradient(false); + + // grad_node Add Edges + std::vector res = {auto_grad_meta1}; + node0_ptr->AddEdges(&res, 0); + } + std::vector outs = {output_tensor}; + + // Run Grad + auto result = Grad(outs, {leaf_tensor}, {}); + // Check Output Value + eager_test::CompareTensorWithValue(result[0], 5.0); +} + +TEST(Grad, SingleNodeCustomGrad) { + // Prepare Device Contexts + eager_test::InitEnv(paddle::platform::CPUPlace()); + + // Prepare Inputs + std::vector target_tensors; + paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); + + // Create Target Tensor + paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + target_tensors.emplace_back(std::move(tensor)); + + std::vector grad_tensors; + // Create Grad Tensor + paddle::experimental::Tensor grad_tensor = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + grad_tensors.emplace_back(std::move(grad_tensor)); + + paddle::experimental::Tensor 
leaf_tensor = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + + { + // Create Scale Node + auto node0_ptr = std::make_shared(1, 1); + node0_ptr->SetAttributes_scale(5.0 /*scale*/); + + // Set grad in/out meta + node0_ptr->SetDefaultGradInOutMeta(); + + // Connect Tensor and Node via AutoGradMeta + AutogradMeta* auto_grad_meta = + EagerUtils::autograd_meta(&(target_tensors[0])); + auto_grad_meta->SetGradNode( + std::dynamic_pointer_cast(node0_ptr)); + auto_grad_meta->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta->SetStopGradient(false); + + AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor); + // Connect Tensor and AccumulationNode via AutoGradMeta + auto acc_node_ptr = + std::make_shared(auto_grad_meta1); + + auto_grad_meta1->SetGradNode( + std::dynamic_pointer_cast(acc_node_ptr)); + auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta1->SetStopGradient(false); + std::vector res = {auto_grad_meta1}; + node0_ptr->AddEdges(&res, 0); + } + + auto result = Grad(target_tensors, {leaf_tensor}, grad_tensors); + + // Check Output Value + eager_test::CompareTensorWithValue(result[0], 50.0); +} + +/* +Node1 + | +Node0 + | + { } // empty grad tensor +*/ +TEST(Grad, LinearNodes) { + // Prepare Device Contexts + eager_test::InitEnv(paddle::platform::CPUPlace()); + + // Prepare Target Tensor + std::vector target_tensors; + paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); + + // Create Target Tensor + paddle::experimental::Tensor tensor = egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + target_tensors.emplace_back(std::move(tensor)); + + paddle::experimental::Tensor leaf_tensor = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, true /*is_leaf*/); + { + // Create Node0 + auto node0_ptr = std::make_shared(1, 1); + node0_ptr->SetAttributes_scale(5.0 /*scale*/); + + // Set grad in/out meta for node0 + node0_ptr->SetDefaultGradInOutMeta(); + + // Create Node1 + auto node1_ptr = std::make_shared(1, 1); + node1_ptr->SetAttributes_scale(10.0 /*scale*/); + + // Set grad in/out meta for node1 + node1_ptr->SetDefaultGradInOutMeta(); + + // Connect Input Tensor and Node0 via AutoGradMeta + AutogradMeta* auto_grad_meta = + EagerUtils::autograd_meta(&(target_tensors[0])); + auto_grad_meta->SetGradNode( + std::dynamic_pointer_cast(node0_ptr)); + auto_grad_meta->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta->SetStopGradient(false); + // Connect Node0 -> Node1 via Edge + auto meta0 = egr::AutogradMeta(); + meta0.SetStopGradient(false); + meta0.SetSingleOutRankWithSlot(0, 0); + meta0.SetGradNode(node1_ptr); + std::vector res0 = {&meta0}; + node0_ptr->AddEdges(&res0, 0); + + AutogradMeta* auto_grad_meta1 = EagerUtils::autograd_meta(&leaf_tensor); + // Connect Tensor and AccumulationNode via AutoGradMeta + auto acc_node_ptr = + std::make_shared(auto_grad_meta1); + + auto_grad_meta1->SetGradNode( + std::dynamic_pointer_cast(acc_node_ptr)); + auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); + + auto_grad_meta1->SetStopGradient(false); + std::vector res1 = {auto_grad_meta1}; + node1_ptr->AddEdges(&res1, 0); + } + + // Use Empty Grad Tensor + auto result = Grad(target_tensors, {leaf_tensor}, {}); + + // Check Output Value + 
eager_test::CompareTensorWithValue(result[0], 50.0); +} + +/* + Node2 + | | +Node0 Node1 + | | + in0 in1 +*/ +TEST(Grad, WithAccumulation) { + // Prepare Device Contexts + eager_test::InitEnv(paddle::platform::CPUPlace()); + + // Prepare Inputs + paddle::framework::DDim ddim = phi::make_ddim({4, 16, 16, 32}); + + // Create Target Tensor + std::vector target_tensors; + paddle::experimental::Tensor tensor0 = egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor tensor1 = egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 1.0 /*value*/, false /*is_leaf*/); + target_tensors.emplace_back(std::move(tensor0)); + target_tensors.emplace_back(std::move(tensor1)); + + // Create Grad Tensor + std::vector grad_tensors; + paddle::experimental::Tensor grad_tensor0 = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 5.0 /*value*/, false /*is_leaf*/); + paddle::experimental::Tensor grad_tensor1 = + egr_utils_api::CreateTensorWithValue( + ddim, paddle::platform::CPUPlace(), phi::DataType::FLOAT32, + phi::DataLayout::NCHW, 10.0 /*value*/, false /*is_leaf*/); + grad_tensors.emplace_back(std::move(grad_tensor0)); + grad_tensors.emplace_back(std::move(grad_tensor1)); + + paddle::experimental::Tensor leaf_tensor; + { + // Create Node0 + auto node0_ptr = std::make_shared(1, 1); + node0_ptr->SetAttributes_scale(5.0 /*scale*/); + node0_ptr->SetDefaultGradInOutMeta(); + + // Create Node1 + auto node1_ptr = std::make_shared(1, 1); + node1_ptr->SetAttributes_scale(10.0 /*scale*/); + node1_ptr->SetDefaultGradInOutMeta(); + // Create Node2 + auto node2_ptr = std::make_shared(1, 1); + node2_ptr->SetAttributes_scale(20.0 /*scale*/); + node2_ptr->SetDefaultGradInOutMeta(); + // Connect Inp0 and Node0 via AutoGradMeta + AutogradMeta* auto_grad_meta0 = + EagerUtils::autograd_meta(&(target_tensors[0])); + auto_grad_meta0->SetGradNode( + std::dynamic_pointer_cast(node0_ptr)); + auto_grad_meta0->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta0->SetStopGradient(false); + // Connect Inp1 and Node1 via AutoGradMeta + AutogradMeta* auto_grad_meta1 = + EagerUtils::autograd_meta(&(target_tensors[1])); + auto_grad_meta1->SetGradNode( + std::dynamic_pointer_cast(node1_ptr)); + auto_grad_meta1->SetSingleOutRankWithSlot(0, 0); + auto_grad_meta1->SetStopGradient(false); + + // Connect Node0 -> Node2 via Edge + auto meta0 = egr::AutogradMeta(); + meta0.SetStopGradient(false); + meta0.SetSingleOutRankWithSlot(0, 0); + meta0.SetGradNode(node2_ptr); + std::vector res0 = {&meta0}; + node0_ptr->AddEdges(&res0, 0); + + // Connect Node1 -> Node2 via Edge + auto meta1 = egr::AutogradMeta(); + meta1.SetStopGradient(false); + meta1.SetSingleOutRankWithSlot(0, 0); + meta1.SetGradNode(node2_ptr); + std::vector res1 = {&meta1}; + node1_ptr->AddEdges(&res1, 0); + + AutogradMeta* auto_grad_meta2 = EagerUtils::autograd_meta(&leaf_tensor); + // Connect Tensor and AccumulationNode via AutoGradMeta + auto acc_node_ptr = + std::make_shared(auto_grad_meta2); + + auto_grad_meta2->SetGradNode( + std::dynamic_pointer_cast(acc_node_ptr)); + auto_grad_meta2->SetSingleOutRankWithSlot(0, 0); + + auto_grad_meta2->SetStopGradient(false); + std::vector res2 = {auto_grad_meta2}; + node2_ptr->AddEdges(&res2, 0); + } + + auto result = Grad(target_tensors, {leaf_tensor}, grad_tensors); + + 
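  // Worked check of the value asserted just below, following the graph built
  // above: Node0 (scale 5.0) and Node1 (scale 10.0) both feed Node2 (scale 20.0),
  // and the provided grads are 5.0 and 10.0, accumulating into leaf_tensor:
  //   branch 0:  5.0 *  5.0 * 20.0 =  500.0
  //   branch 1: 10.0 * 10.0 * 20.0 = 2000.0
  //   total:     500.0 + 2000.0    = 2500.0
  // The static_assert is only an illustrative sanity check of that arithmetic,
  // not part of the original test.
  static_assert(5.0 * 5.0 * 20.0 + 10.0 * 10.0 * 20.0 == 2500.0,
                "accumulated gradient should be 2500.0");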
eager_test::CompareTensorWithValue(result[0], 2500.0); +} + +} // namespace egr diff --git a/paddle/fluid/eager/tests/task_tests/hook_test.cc b/paddle/fluid/eager/tests/task_tests/hook_test.cc index d546df4ed087a99a28096a5336fab3826991534a..2c53fc89f650e36f1435c7e1e805453fe7822cf2 100644 --- a/paddle/fluid/eager/tests/task_tests/hook_test.cc +++ b/paddle/fluid/eager/tests/task_tests/hook_test.cc @@ -132,7 +132,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) { leaf_tensor); // result: 4.0*5.0 + 3.0 = 23.0 } - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(target_tensor, 4.0); eager_test::CompareGradTensorWithValue(leaf_tensor, 23.0); @@ -199,7 +199,7 @@ TEST(RetainGrad, HookAfterRetainGrad) { leaf_tensor, std::make_shared(hook_function)); } - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(target_tensor, 1.0); eager_test::CompareGradTensorWithValue(leaf_tensor, 23.0); } diff --git a/paddle/fluid/eager/tests/task_tests/hook_test_intermidiate.cc b/paddle/fluid/eager/tests/task_tests/hook_test_intermidiate.cc index 56813c498d2410caa452da7a334c393b230c65bf..0ee171c73c6600b95b9b093ef7e818855f53002d 100644 --- a/paddle/fluid/eager/tests/task_tests/hook_test_intermidiate.cc +++ b/paddle/fluid/eager/tests/task_tests/hook_test_intermidiate.cc @@ -108,7 +108,7 @@ void test_sigmoid(bool is_remove_gradient_hook) { } VLOG(6) << "Runing Backward"; - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); VLOG(6) << "Finish Backward"; eager_test::CompareGradTensorWithValue( @@ -166,7 +166,7 @@ void test_elementwiseAdd(bool is_remove_gradient_hook) { grad_node_tmp->RemoveGradientHook(hook_id); } - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(X, 1.0); eager_test::CompareGradTensorWithValue( @@ -224,7 +224,7 @@ void test_matmul(bool is_remove_gradient_hook) { grad_node_tmp->RemoveGradientHook(hook_id); } - RunBackward(target_tensors, {}); + Backward(target_tensors, {}); eager_test::CompareGradTensorWithValue(X, 2.0 * 20); eager_test::CompareGradTensorWithValue( diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index d99624e49324853d513a20a725c1a3d12b6aaab5..4eaa64d3ac659ca0ec76083b70855d8b6b241556 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -370,8 +370,8 @@ class GradNodeRunProgram : public egr::GradNodeBase { ~GradNodeRunProgram() override = default; // Functor: perform backward computations virtual std::vector> operator()( - const std::vector> &grads) - override { + const std::vector> &grads, + bool create_graph) override { VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram"; PADDLE_ENFORCE_EQ( grads.size(), 1, @@ -415,6 +415,12 @@ class GradNodeRunProgram : public egr::GradNodeBase { // return {x_grad, details::DereferenceTensors(params_grad_ptr)}; } + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } + // SetAttrMap void SetAttrMap(const paddle::framework::AttributeMap &attrs) { attrs_ = attrs; diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index e110432c67d395c865d934a47eaa4a803053db8b..c9e80c7b4b407456fc962f508ae441a9c07914b2 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ 
b/paddle/fluid/pybind/eager_functions.cc @@ -122,13 +122,33 @@ static PyObject* eager_api_run_backward(PyObject* self, PyObject* args, EAGER_TRY auto tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0); auto grad_tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 1), 1); - egr::RunBackward(tensors, grad_tensors, - CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2)); + egr::Backward(tensors, grad_tensors, + CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 2), 2)); Py_INCREF(Py_None); return Py_None; EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* eager_api_run_partial_grad(PyObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + auto tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0); + auto inputs = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 1), 1); + auto grad_tensors = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 2), 2); + auto retain_graph = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3); + auto create_graph = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4); + auto only_inputs = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 5), 5); + auto allow_unused = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 6), 6); + auto no_grad_vars = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 7), 7); + + std::vector result = + egr::Grad(tensors, inputs, grad_tensors, retain_graph, create_graph, + only_inputs, allow_unused, no_grad_vars); + VLOG(1) << " in eager_api_run_partial_grad, after runing egr::Grad"; + return ToPyObject(result, true /* return_py_none_if_not_initialize */); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY @@ -452,6 +472,9 @@ PyMethodDef variable_functions[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"run_backward", (PyCFunction)(void (*)(void))eager_api_run_backward, METH_VARARGS | METH_KEYWORDS, NULL}, + {"run_partial_grad", + (PyCFunction)(void (*)(void))eager_api_run_partial_grad, + METH_VARARGS | METH_KEYWORDS, NULL}, {"_run_custom_op", (PyCFunction)(void (*)(void))eager_api_run_costum_op, METH_VARARGS | METH_KEYWORDS, NULL}, {"tensor_copy", (PyCFunction)(void (*)(void))eager_api_tensor_copy, diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 217edad0c0a105cc649c6c8c4433b0c8eab0119b..97bb32630d71368de2dee205fbef186a8551d9c7 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -492,20 +492,26 @@ PyObject* ToPyObject(const std::vector& value) { return result; } -PyObject* ToPyObject(const std::vector& value) { +PyObject* ToPyObject(const std::vector& value, + bool return_py_none_if_not_initialize) { PyObject* result = PyList_New((Py_ssize_t)value.size()); for (size_t i = 0; i < value.size(); i++) { - PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0); - if (obj) { - auto v = reinterpret_cast(obj); - new (&(v->tensor)) paddle::experimental::Tensor(); - v->tensor = value[i]; + if (!value[i].initialized() && return_py_none_if_not_initialize) { + Py_INCREF(Py_None); + PyList_SET_ITEM(result, static_cast(i), Py_None); } else { - PADDLE_THROW(platform::errors::Fatal( - "tp_alloc return null, can not new a PyObject.")); + PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0); + if (obj) { + auto v = reinterpret_cast(obj); + new (&(v->tensor)) paddle::experimental::Tensor(); + v->tensor = value[i]; + } else { + PADDLE_THROW(platform::errors::Fatal( + "tp_alloc return null, can not new a PyObject.")); + } + PyList_SET_ITEM(result, static_cast(i), 
obj); } - PyList_SET_ITEM(result, static_cast(i), obj); } return result; diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 2187555e1c3c7f64bd864e4212bfc6ebe1fb1684..1c4e2ab69a5ecba1209a11651c3c11972dff565c 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -68,7 +68,8 @@ PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); PyObject* ToPyObject(const std::vector& value); -PyObject* ToPyObject(const std::vector& value); +PyObject* ToPyObject(const std::vector& value, + bool return_py_none_if_not_initialize = false); PyObject* ToPyObject(const platform::Place& value); PyObject* ToPyObject(const framework::LoDTensor* value); PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype); diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 8149d69d36a27fadcefa8dc6b6ff1dd89792e29e..9439982858530e1e81156be4b32ef2d91dc4a33a 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -565,16 +565,25 @@ def grad(outputs, if isinstance(in_out_list, (list, tuple)): assert len(in_out_list) > 0, "{} cannot be empty".format(name) for each_var in in_out_list: - assert isinstance( - each_var, - core.VarBase), "Elements of {} must be Variable".format( - name) + if core._in_eager_mode(): + assert isinstance( + each_var, core.eager. + Tensor), "Elements of {} must be Tensor".format(name) + else: + assert isinstance( + each_var, + core.VarBase), "Elements of {} must be Variable".format( + name) return in_out_list else: - assert isinstance( - in_out_list, - core.VarBase), "{} must be Variable or list of Variable".format( - name) + if core._in_eager_mode(): + assert isinstance( + in_out_list, core.eager. 
+ Tensor), "{} must be Tensor or list of Tensor".format(name) + else: + assert isinstance( + in_out_list, core.VarBase + ), "{} must be Variable or list of Variable".format(name) return [in_out_list] outputs = check_in_out(outputs, 'outputs') @@ -586,9 +595,14 @@ def grad(outputs, for each_var in grad_outputs: if each_var is not None: - assert isinstance( - each_var, core.VarBase - ), "grad_outputs must be None, a Variable or a list containing None or Variables" + if core._in_eager_mode(): + assert isinstance( + each_var, core.eager.Tensor + ), "grad_outputs must be None, a Variable or a list containing None or Variables" + else: + assert isinstance( + each_var, core.VarBase + ), "grad_outputs must be None, a Variable or a list containing None or Variables" else: grad_outputs = [] @@ -600,14 +614,27 @@ def grad(outputs, no_grad_vars = [] elif isinstance(no_grad_vars, core.VarBase): no_grad_vars = [no_grad_vars] + elif isinstance(no_grad_vars, core.eager.Tensor): + no_grad_vars = [no_grad_vars] elif isinstance(no_grad_vars, (list, tuple, set)): no_grad_vars = list(no_grad_vars) for var in no_grad_vars: - assert isinstance( - var, core.VarBase), "no_grad_vars can only contains Variable" + if core._in_eager_mode(): + assert isinstance( + var, + core.eager.Tensor), "no_grad_vars can only contains Tensor" + else: + assert isinstance( + var, + core.VarBase), "no_grad_vars can only contains Variable" else: - raise AssertionError( - "no_grad_vars must be None, Variable or list/tuple/set of Variables") + if core._in_eager_mode(): + raise AssertionError( + "no_grad_vars must be None, Tensor or list/tuple/set of Tensors") + else: + raise AssertionError( + "no_grad_vars must be None, Variable or list/tuple/set of Variables" + ) assert isinstance(create_graph, bool), "create_graph must be True or False" @@ -622,6 +649,11 @@ def grad(outputs, assert isinstance(only_inputs, bool), "only_inputs must be True or False" assert only_inputs, "only_inputs=False is not supported yet" + if core._in_eager_mode(): + return core.eager.run_partial_grad( + outputs, inputs, grad_outputs, retain_graph, create_graph, + only_inputs, allow_unused, no_grad_vars) + place = core.Place() place.set_place(framework._current_expected_place()) return core.dygraph_partial_grad(inputs, outputs, grad_outputs, diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index 27aec284de4cdebb5ebb9191bfb67d48c1b327f5..98ef339e04535bb943add02b6cf6efe490f0354b 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -52,7 +52,7 @@ class EagerScaleTestCase(unittest.TestCase): out_eager = core.eager.scale(data_eager, 1.0, 0.9, True, True) self.assertIsNone(data_eager.grad) out_eager.backward(grad_eager, False) - self.assertTrue(data_eager.grad._is_initialized()) + self.assertIsNotNone(data_eager.grad) self.assertTrue(np.array_equal(data_eager.grad.numpy(), input_data)) def test_retain_grad_and_run_backward_raises(self): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index cd4ba5b054264afca65d4c4d8359eb1854fbb658..7436e9eb7b12623296d7a714e742cc4212c4ca91 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,9 @@ from paddle.vision.models import resnet50, resnet101 import unittest from unittest import TestCase import numpy as np +import paddle.compat as cpt +from paddle.fluid.framework import _test_eager_guard +import paddle.fluid.core as core def _dygraph_guard_(func): @@ -40,6 +43,80 @@ def random_var(size, low=-1, high=1, dtype='float32'): return fluid.dygraph.to_variable(x_np) +class TestEagerGrad(TestCase): + def func_simple_example_eager_grad(self): + np.random.seed(2021) + paddle.set_device('cpu') + np_x = np.random.random((3, 3)) + np_y = np.random.random((3, 1)) + x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False) + y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False) + out = paddle.matmul(x, y) + dx = fluid.dygraph.grad(out, x) + + dout = np.ones_like(np_y) + expected_dx = np.matmul(dout, np.transpose(np_y)) + + # stop_gradient = !create_graph, create_graph default false + self.assertEqual(dx[0].stop_gradient, True) + self.assertTrue(np.allclose(dx[0].numpy(), expected_dx[0])) + + def test_simple_example_eager_grad(self): + with _test_eager_guard(): + self.func_simple_example_eager_grad() + self.func_simple_example_eager_grad() + + def func_simple_example_eager_grad_allow_unused(self): + np.random.seed(2021) + paddle.set_device('cpu') + np_x = np.random.random((3, 3)) + np_y = np.random.random((3, 1)) + np_z = np.random.random((3, 1)) + x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False) + y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False) + z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False) + out_z = paddle.nn.functional.sigmoid(z) + out = paddle.matmul(x, y) + + dx = fluid.dygraph.grad(out, [x, z], allow_unused=True) + dout = np.ones_like(np_y) + expected_dx = np.matmul(dout, np.transpose(np_y)) + self.assertTrue(np.allclose(dx[0].numpy(), expected_dx[0])) + # stop_gradient = !create_graph, create_graph default false + self.assertEqual(dx[0].stop_gradient, True) + # x is unused input in the graph + self.assertEqual(dx[1], None) + + def test_simple_example_eager_grad_allow_unused(self): + with _test_eager_guard(): + self.func_simple_example_eager_grad_allow_unused() + self.func_simple_example_eager_grad_allow_unused() + + def func_simple_example_eager_grad_not_allow_unused(self): + np.random.seed(2021) + paddle.set_device('cpu') + np_x = np.random.random((3, 3)) + np_y = np.random.random((3, 1)) + np_z = np.random.random((3, 1)) + x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False) + y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False) + z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False) + out_z = paddle.nn.functional.sigmoid(z) + out = paddle.matmul(x, y) + + try: + # allow_unused is false in default + dx = fluid.dygraph.grad(out, [x, z]) + except ValueError as e: + error_msg = cpt.get_exception_message(e) + assert error_msg.find("allow_unused") > 0 + + def test_simple_example_eager_grad_not_allow_unused(self): + with _test_eager_guard(): + self.func_simple_example_eager_grad_not_allow_unused() + self.func_simple_example_eager_grad_not_allow_unused() + + class TestDygraphDoubleGrad(TestCase): def setUp(self): self.sort_sum_gradient = False @@ -64,7 +141,7 @@ class TestDygraphDoubleGrad(TestCase): allow_unused=allow_unused) @dygraph_guard - def test_exception(self): + 
def func_exception(self): with self.assertRaises(AssertionError): self.grad(None, None) @@ -93,8 +170,13 @@ class TestDygraphDoubleGrad(TestCase): with self.assertRaises(AssertionError): self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1) + def test_exception(self): + with _test_eager_guard(): + self.func_exception() + self.func_exception() + @dygraph_guard - def test_simple_example(self): + def func_simple_example(self): x = random_var(self.shape) x.stop_gradient = False y = x + 1 @@ -123,8 +205,44 @@ class TestDygraphDoubleGrad(TestCase): self.assertNotEqual(grad_with_none_and_not_none.stop_gradient, create_graph) + def test_simple_example(self): + with _test_eager_guard(): + self.func_simple_example() + self.func_simple_example() + @dygraph_guard - def test_none_one_initial_gradient(self): + def func_example_no_grad_vars(self): + x = random_var(self.shape) + x_np = x.numpy() + numel = x_np.size + x.stop_gradient = False + + y1 = fluid.layers.relu(x) + y2 = fluid.layers.relu(x) + z = y1 + y2 + w = z * z + + w_mean = fluid.layers.reduce_mean(w) + del y1, z, w + + dx_actual, = self.grad( + [w_mean], [x], create_graph=True, no_grad_vars=[y2]) + + self.assertFalse(y2.stop_gradient) + self.assertFalse(dx_actual.stop_gradient) + + dx_expected = (1.0 / float(numel) * (np.maximum(x_np, 0) + y2.numpy()) * + (x_np > 0) * 2).astype('float32') + + self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) + + def test_example_no_grad_vars(self): + with _test_eager_guard(): + self.func_example_no_grad_vars() + self.func_example_no_grad_vars() + + @dygraph_guard + def func_none_one_initial_gradient(self): numel = 1 for s in self.shape: numel *= s @@ -190,8 +308,13 @@ class TestDygraphDoubleGrad(TestCase): np.array_equal(grad_z.numpy(), original_random_grad_z)) + def test_none_one_initial_gradient(self): + with _test_eager_guard(): + self.func_none_one_initial_gradient() + self.func_none_one_initial_gradient() + @dygraph_guard - def test_example_with_gradient_accumulation_and_create_graph(self): + def func_example_with_gradient_accumulation_and_create_graph(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -214,25 +337,33 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward(retain_graph=True) - - x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * - (x_np + dx_expected * - (x_np > 0) * 2 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) - - for i in range(5): + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) loss.backward(retain_graph=True) + x_grad_actual = x.gradient() - x_grad_expected = (i + 2) * (2.0 / float(numel) * ( + x_grad_expected = (2.0 / float(numel) * ( x_np + dx_expected * (x_np > 0) * 2 / float(numel))).astype('float32') self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + for i in range(5): + loss.backward(retain_graph=True) + x_grad_actual = x.gradient() + x_grad_expected = (i + 2) * (2.0 / float(numel) * ( + x_np + dx_expected * + (x_np > 0) * 2 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + def test_example_with_gradient_accumulation_and_create_graph(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_create_graph() + 
self.func_example_with_gradient_accumulation_and_create_graph() + @dygraph_guard - def test_example_with_gradient_accumulation_and_no_grad_vars(self): + def func_example_with_gradient_accumulation_and_no_grad_vars(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -256,17 +387,25 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * - (x_np + dx_expected * - (x_np > 0) * 4 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 / float(numel) * ( + x_np + dx_expected * + (x_np > 0) * 4 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + def test_example_with_gradient_accumulation_and_no_grad_vars(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_no_grad_vars() + self.func_example_with_gradient_accumulation_and_no_grad_vars() @dygraph_guard - def test_example_with_gradient_accumulation_and_not_create_graph(self): + def func_example_with_gradient_accumulation_and_not_create_graph(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -289,12 +428,20 @@ class TestDygraphDoubleGrad(TestCase): self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + def test_example_with_gradient_accumulation_and_not_create_graph(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_not_create_graph() + self.func_example_with_gradient_accumulation_and_not_create_graph() class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad): @@ -304,7 +451,7 @@ class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad): class TestDygraphDoubleGradVisitedUniq(TestCase): - def test_compare(self): + def func_compare(self): value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2, 5).astype("float32") @@ -349,6 +496,11 @@ class TestDygraphDoubleGradVisitedUniq(TestCase): self.assertTrue(np.array_equal(grad_1, grad_2)) + def test_compare(self): + with _test_eager_guard(): + self.func_compare() + self.func_compare() + class TestRaiseNoDoubleGradOp(TestCase): def raise_no_grad_op(self): diff --git a/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py index 2ffe523ef6dda18a24813e702a1892c335ba6a68..531e9663a2b728a2871dff404425b063a0c47e67 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import unittest from unittest import TestCase import numpy as np import paddle +from paddle.fluid.framework import _test_eager_guard +import paddle.fluid.core as core def _dygraph_guard_(func): @@ -62,7 +64,7 @@ class TestDygraphDoubleGrad(TestCase): allow_unused=allow_unused) @dygraph_guard - def test_exception(self): + def func_exception(self): with self.assertRaises(AssertionError): self.grad(None, None) @@ -91,8 +93,13 @@ class TestDygraphDoubleGrad(TestCase): with self.assertRaises(AssertionError): self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1) + def test_exception(self): + with _test_eager_guard(): + self.func_exception() + self.func_exception() + @dygraph_guard - def test_simple_example(self): + def func_simple_example(self): x = random_var(self.shape) x.stop_gradient = False y = x + 1 @@ -121,8 +128,13 @@ class TestDygraphDoubleGrad(TestCase): self.assertNotEqual(grad_with_none_and_not_none.stop_gradient, create_graph) + def test_simple_example(self): + with _test_eager_guard(): + self.func_simple_example() + self.func_simple_example() + @dygraph_guard - def test_none_one_initial_gradient(self): + def func_none_one_initial_gradient(self): numel = 1 for s in self.shape: numel *= s @@ -188,8 +200,13 @@ class TestDygraphDoubleGrad(TestCase): np.array_equal(grad_z.numpy(), original_random_grad_z)) + def test_none_one_initial_gradient(self): + with _test_eager_guard(): + self.func_none_one_initial_gradient() + self.func_none_one_initial_gradient() + @dygraph_guard - def test_example_with_gradient_accumulation_and_create_graph(self): + def func_example_with_gradient_accumulation_and_create_graph(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -212,17 +229,25 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * - (x_np + dx_expected * - (x_np > 0) * 2 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 / float(numel) * ( + x_np + dx_expected * + (x_np > 0) * 2 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + def test_example_with_gradient_accumulation_and_create_graph(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_create_graph() + self.func_example_with_gradient_accumulation_and_create_graph() @dygraph_guard - def test_example_with_gradient_accumulation_and_no_grad_vars(self): + def func_example_with_gradient_accumulation_and_no_grad_vars(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -246,17 +271,25 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() + + x_grad_actual 
= x.gradient() + x_grad_expected = (2.0 / float(numel) * ( + x_np + dx_expected * + (x_np > 0) * 4 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) - x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * - (x_np + dx_expected * - (x_np > 0) * 4 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + def test_example_with_gradient_accumulation_and_no_grad_vars(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_no_grad_vars() + self.func_example_with_gradient_accumulation_and_no_grad_vars() @dygraph_guard - def test_example_with_gradient_accumulation_and_not_create_graph(self): + def func_example_with_gradient_accumulation_and_not_create_graph(self): x = random_var(self.shape) x_np = x.numpy() numel = x_np.size @@ -279,12 +312,20 @@ class TestDygraphDoubleGrad(TestCase): self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + if core._in_eager_mode(): + pass + else: + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + def test_example_with_gradient_accumulation_and_not_create_graph(self): + with _test_eager_guard(): + self.func_example_with_gradient_accumulation_and_not_create_graph() + self.func_example_with_gradient_accumulation_and_not_create_graph() class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad):
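
The hunks above wire the new eager-mode grad path end to end: egr::Grad is exposed to Python as core.eager.run_partial_grad, ToPyObject gains a return_py_none_if_not_initialize switch, and fluid.dygraph.grad dispatches to the new binding whenever core._in_eager_mode() is true. The following stand-alone sketch (not part of the patch) shows how that path behaves from the Python side; it mirrors the TestEagerGrad cases added in test_imperative_double_grad.py, the helper name eager_grad_demo is illustrative only, and it assumes a Paddle build in which the experimental eager mode and _test_eager_guard are available, exactly as imported by the tests in this diff.

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard


def eager_grad_demo():
    # Illustrative helper; mirrors func_simple_example_eager_grad_allow_unused above.
    paddle.set_device('cpu')
    np_x = np.random.random((3, 3))
    np_y = np.random.random((3, 1))
    np_z = np.random.random((3, 1))
    x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
    y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
    z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
    out = paddle.matmul(x, y)  # note: z is never used in the graph

    # In eager mode, fluid.dygraph.grad forwards to core.eager.run_partial_grad
    # (see the base.py hunk above).
    dx, dz = fluid.dygraph.grad(out, [x, z], allow_unused=True)

    # create_graph defaults to False, so the returned gradient is detached.
    assert dx.stop_gradient is True
    # z did not contribute to out; with ToPyObject(...,
    # return_py_none_if_not_initialize=True) its slot comes back as None
    # instead of raising.
    assert dz is None


if __name__ == '__main__':
    with _test_eager_guard():
        eager_grad_demo()

Returning Py_None for uninitialized gradients keeps the new binding consistent with the allow_unused semantics of the legacy dygraph_partial_grad path, where unused inputs likewise yield None rather than an error.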
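
The test-file changes themselves follow one mechanical convention: every original test_<name> body becomes func_<name>, and a new test_<name> wrapper runs it twice so both execution modes stay covered. A compressed sketch of that pattern, with an illustrative class and method name not taken from the patch, assuming the same _test_eager_guard import the diff uses:

from unittest import TestCase
from paddle.fluid.framework import _test_eager_guard


class ExampleDualModeTest(TestCase):
    def func_some_behavior(self):
        # original test body, written once
        pass

    def test_some_behavior(self):
        with _test_eager_guard():
            self.func_some_behavior()  # eager Tensor path (egr::Backward / egr::Grad)
        self.func_some_behavior()      # legacy VarBase path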