Unverified · Commit c62a7e25 · Author: Jiabin Yang · Committed by: GitHub

[Eager] Fix edvr starganv2 (#43471)

* fix starganv2

* fix starganv2 stop_gradient end error

* fix edvr_starganv2

* fix mul kernel to fix optional ddx

* fix typo
Parent 8cec1271
@@ -1152,7 +1152,8 @@ static std::string GenerateGradNodeCreationContent(
   size_t bwd_in_slot_num = out_vars.size();
   size_t bwd_out_slot_num = in_vars.size();
   const char* GRAD_OP_NODE_TEMPLATE =
-      "  auto grad_node = std::shared_ptr<GradNode%s>(new GradNode%s(%d, "
+      "  auto grad_node = std::shared_ptr<%sGradNodeCompat>(new "
+      "%sGradNodeCompat(%d, "
       "%d));\n";
   grad_node_creation_str += "  // Create GradOpNode\n";
   grad_node_creation_str +=
@@ -2080,10 +2081,8 @@ static std::string GenerateSingleOpBase(
   generated_grad_function_body +=
       "  paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
       "egr::kSlotSmallVectorSize> " +
-      hooked_grads +
-      " = "
-      "GradNode" +
-      fwd_op_type + "::ApplyGradientHooks(grads);\n";
+      hooked_grads + " = " + fwd_op_type +
+      "GradNodeCompat::ApplyGradientHooks(grads);\n";

   // [Generation] Get Ins Map
   std::unordered_set<std::string> dispensable_input_name_set;
@@ -2547,7 +2546,7 @@ static std::string GenerateGradNodeCCContents(
   */
   const char* EAGER_LOG_TEMPLATE =
-      "  VLOG(3) << \"Running Eager Backward Node: GradNode%s\";\n";
+      "  VLOG(3) << \"Running Eager Backward Node: %sGradNodeCompat\";\n";
   std::string generated_grad_function_body =
       paddle::string::Sprintf(EAGER_LOG_TEMPLATE, fwd_op_type);
@@ -2616,7 +2615,7 @@ static std::string GenerateGradNodeCCContents(
   const char* GRAD_FUNCTION_TEMPLATE =
       "paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
       "egr::kSlotSmallVectorSize> "
-      "GradNode%s::operator()("
+      "%sGradNodeCompat::operator()("
       "paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
       "egr::kSlotSmallVectorSize>& grads, bool "
       "create_graph, bool is_new_grad) {\n"
@@ -2645,14 +2644,15 @@ static std::string GenerateGradNodeHeaderContents(
   VLOG(6) << "Generating Grad Node Header";
   const char* GRAD_NODE_TEMPLATE =
-      "class GradNode%s : public egr::GradNodeBase {\n"
+      "class %sGradNodeCompat : public egr::GradNodeBase {\n"
       " public:\n"
-      "  GradNode%s() : egr::GradNodeBase() { VLOG(7) << \" Construct "
-      "GradNode%s \"; }\n"
-      "  GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
+      "  %sGradNodeCompat() : egr::GradNodeBase() { VLOG(7) << \" Construct "
+      "%sGradNodeCompat \"; }\n"
+      "  %sGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
       "egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { VLOG(7) << \" "
-      "Construct GradNode%s \"; }\n"
-      "  ~GradNode%s() override { VLOG(6) << \" Destruct GradNode%s \"; }\n"
+      "Construct %sGradNodeCompat \"; }\n"
+      "  ~%sGradNodeCompat() override { VLOG(6) << \" Destruct "
+      "%sGradNodeCompat \"; }\n"
       "\n"
       "  virtual "
       "paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
@@ -2667,11 +2667,11 @@ static std::string GenerateGradNodeHeaderContents(
       "%s\n"
       "  SetIsTensorWrappersCleared(true);\n"
       "  }\n"
-      "  std::string name() override { return \"GradNode%sMid\"; } \n "
+      "  std::string name() override { return \"%sGradNodeCompat\"; } \n "
       "\n"
       "std::shared_ptr<GradNodeBase> Copy() const override {{\n "
-      "    auto copied_node = std::shared_ptr<GradNode%s>(new "
-      "GradNode%s(*this));\n "
+      "    auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
+      "%sGradNodeCompat(*this));\n "
       "    return copied_node;\n "
       "}}\n "
       "\n"
......
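To see what the renamed templates expand to, here is a small Python sketch (not part of the patch) that emulates the Sprintf substitution; the op type "Multiply" and the slot counts are hypothetical placeholders.

# Minimal sketch: emulate the Sprintf expansion of the renamed template.
# The op type "Multiply" and the slot counts (1, 2) are hypothetical.
GRAD_OP_NODE_TEMPLATE = ("  auto grad_node = std::shared_ptr<%sGradNodeCompat>(new "
                         "%sGradNodeCompat(%d, "
                         "%d));\n")

print(GRAD_OP_NODE_TEMPLATE % ("Multiply", "Multiply", 1, 2), end="")
# ->   auto grad_node = std::shared_ptr<MultiplyGradNodeCompat>(new MultiplyGradNodeCompat(1, 2));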
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -147,7 +147,18 @@ def RemoveConstAndReference(string):
 def GetGradNodeName(string):
-    return f"GradNode{string}Final"
+
+    def str2Hump(text):
+        arr = filter(None, text.split('_'))
+        res = ''
+        for i in arr:
+            res = res + i[0].upper() + i[1:]
+        return res
+
+    string = str2Hump(string)
+    if string.rfind("Grad") == (len(string) - 4):
+        string = string[:-4]
+    return f"{string}GradNodeFinal"


 def GetDygraphForwardFunctionName(string):
@@ -335,6 +346,7 @@ def ParseYamlInplaceInfo(string):
 ### Generator Base ###
 ########################
 class FunctionGeneratorBase:
+
     def __init__(self, forward_api_contents, namespace):
         self.forward_api_contents = forward_api_contents
         self.namespace = namespace
@@ -357,7 +369,7 @@ class FunctionGeneratorBase:
         # Special Op Attributes
         self.optional_inputs = []  #[name, ...]
         self.no_need_buffers = []  #[name, ...]
         self.intermediate_outputs = []  #[name, ...]
         self.forward_inplace_map = {}  #{name : name, ...}

     def ParseForwardInplaceInfo(self):
@@ -423,8 +435,9 @@ class FunctionGeneratorBase:
             input_type = forward_input[1]
             input_pos = forward_input[2]

-            self.forward_inputs_position_map[
-                input_name] = [input_type, input_pos]
+            self.forward_inputs_position_map[input_name] = [
+                input_type, input_pos
+            ]

         for i in range(len(forward_returns_list)):
             forward_return = forward_returns_list[i]
@@ -432,11 +445,13 @@ class FunctionGeneratorBase:
             return_type = forward_return[1]
             return_pos = forward_return[2]

-            self.forward_outputs_position_map[
-                return_name] = [return_type, return_pos]
+            self.forward_outputs_position_map[return_name] = [
+                return_type, return_pos
+            ]


 class GeneratorBase:
+
     def __init__(self, api_yaml_path):
         self.namespace = ""
         self.api_yaml_path = api_yaml_path
......
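To illustrate the new GetGradNodeName behavior, the sketch below reuses the patched logic and feeds it backward API names of our own choosing ("multiply_double_grad", "relu_grad"): the snake_case name is camel-cased, a trailing "Grad" is stripped, and "GradNodeFinal" is appended.

def GetGradNodeName(string):
    # Same logic as the patched helper above, reproduced for illustration.
    def str2Hump(text):
        arr = filter(None, text.split('_'))
        res = ''
        for i in arr:
            res = res + i[0].upper() + i[1:]
        return res

    string = str2Hump(string)
    if string.rfind("Grad") == (len(string) - 4):
        string = string[:-4]
    return f"{string}GradNodeFinal"

print(GetGradNodeName("multiply_double_grad"))  # MultiplyDoubleGradNodeFinal
print(GetGradNodeName("relu_grad"))             # ReluGradNodeFinal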
@@ -411,6 +411,7 @@ def GenerateCoreOpInfoDefinition():
 ## Generator Class ##
 #####################
 class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
+
     def __init__(self, forward_api_contents, grad_api_contents, namespace):
         self.forward_api_contents = forward_api_contents
         # Members from Parent:
@@ -532,8 +533,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             max_input_position = max(max_input_position, pos)

         for _, _, _, pos in forward_attrs_list:
-            assert pos > max_input_position, AssertMessage(pos,
-                                                           max_input_position)
+            assert pos > max_input_position, AssertMessage(
+                pos, max_input_position)

     def BackwardValidationCheck(self):
         backward_forward_inputs_map = self.backward_forward_inputs_map
@@ -678,7 +679,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         # Node Construction
         num_backward_inputs = len(forward_outputs_position_map.keys())
         num_backward_outputs = len(forward_inputs_position_map.keys())
-        grad_node_name = GetGradNodeName(forward_api_name)
+        grad_node_name = GetGradNodeName(self.backward_api_name)

         # Helper
         indent = GetIndent(2)
@@ -845,6 +846,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
 class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
+
     def __init__(self, forward_api_contents, grad_api_contents, namespace):
         DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                               grad_api_contents, namespace)
@@ -947,12 +949,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         if is_inplaced and len(forward_outputs_position_map) == 1:
             api_out_type = "auto&"
         forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
-        num_outputs = len(forward_outputs_position_map.keys()) - len(
-            intermediate_outputs)
+        num_outputs = len(
+            forward_outputs_position_map.keys()) - len(intermediate_outputs)

         # Check Nan and Inf
-        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(function_name,
-                                                              "api_result")
+        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
+            function_name, "api_result")

         # Get Outputs
         get_outputs_str = ""
@@ -1007,8 +1009,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
                 if pos == corresponding_pos:
                     has_corresponding_grad_output = True
             if has_corresponding_grad_output or (
-                    name in forward_inplace_map and
-                    forward_api_name not in inplace_check_blacklist):
+                    name in forward_inplace_map
+                    and forward_api_name not in inplace_check_blacklist):
                 input_autograd_meta_name = GetAutoGradMetaName(name)
                 if IsPlainTensorType(ttype):
                     input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
@@ -1116,17 +1118,20 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         forward_outputs_position_map = self.forward_outputs_position_map
         forward_attrs_list = self.forward_attrs_list

-        num_args = len(forward_inputs_position_map.keys()) + len(
-            forward_attrs_list)
+        num_args = len(
+            forward_inputs_position_map.keys()) + len(forward_attrs_list)
         num_returns = len(forward_outputs_position_map.keys())

         final_state_fwd_api_name = "final_state_" + forward_api_name
-        core_ops_returns_info[
-            final_state_fwd_api_name] = ["" for i in range(num_returns)]
-        core_ops_args_info[
-            final_state_fwd_api_name] = ["" for i in range(num_args)]
-        core_ops_args_type_info[
-            final_state_fwd_api_name] = ["" for i in range(num_args)]
+        core_ops_returns_info[final_state_fwd_api_name] = [
+            "" for i in range(num_returns)
+        ]
+        core_ops_args_info[final_state_fwd_api_name] = [
+            "" for i in range(num_args)
+        ]
+        core_ops_args_type_info[final_state_fwd_api_name] = [
+            "" for i in range(num_args)
+        ]

         for name, (ttype, pos) in forward_inputs_position_map.items():
             core_ops_args_info[final_state_fwd_api_name][pos] = name
@@ -1159,6 +1164,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
 class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
+
     def __init__(self,
                  forward_api_contents,
                  grad_api_contents,
@@ -1167,7 +1173,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                               grad_api_contents, namespace)

-        # Record name mapping from forward_api_name to grad_api_names
+        # Record name mapping from forward_var_name to grad_var_names
         self.to_next_grad_name_mapping = {}  # {name : name}

         # Generated Results
@@ -1281,7 +1287,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                 RemoveConstAndReference(atype), saved_attr_name)

-        grad_node_name = GetGradNodeName(forward_op_name)
+        grad_node_name = GetGradNodeName(self.backward_api_name)
         self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
             grad_node_name, grad_node_name, grad_node_name, grad_node_name,
             grad_node_name, clear_tensor_wrapper_str, grad_node_name,
@@ -1447,8 +1453,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""

         # Check Nan and Inf
-        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(backward_api_name,
-                                                              "returns")
+        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
+            backward_api_name, "returns")

         # Prepare for Node Creation if Necessary
         inputs_autograd_meta_str = ""
@@ -1533,7 +1539,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
         returns_str += f"{indent}return returns;\n"

-        grad_node_name = GetGradNodeName(forward_api_name)
+        grad_node_name = GetGradNodeName(self.backward_api_name)
         self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
             grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
@@ -1560,6 +1566,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 class DygraphForwardAndNodesGenerator(GeneratorBase):
+
     def __init__(self, api_yaml_path, backward_yaml_path):
         # Parent members:
         # self.namespace
@@ -1617,9 +1624,10 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
                 next_grad_api_contents = self.GetBackwardAPIContents(
                     backward_api_contents)

-                node_generator = DygraphNodeGenerator(
-                    forward_api_contents, backward_api_contents, namespace,
-                    next_grad_api_contents)
+                node_generator = DygraphNodeGenerator(forward_api_contents,
+                                                      backward_api_contents,
+                                                      namespace,
+                                                      next_grad_api_contents)
                 node_generator.run()
                 self.node_declaration_str += node_generator.node_declaration_str + "\n"
                 self.node_definition_str += node_generator.node_definition_str + "\n"
......
@@ -536,7 +536,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     const std::vector<paddle::experimental::Tensor>& inputs = {},
     bool allow_unused = false,
     const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
-  VLOG(6) << "Start Backward";
+  VLOG(3) << "Start Backward";

   // *Gradient Hook should happen at node-level
   // *Inplace version check should perform at node-level
@@ -634,7 +634,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
   }

-  VLOG(6) << "Update In degree Map for backward";
+  VLOG(3) << "Update In degree Map for backward";
   // 3. Compute in_degree for each node
   std::unordered_map<GradNodeBase*, int> node_in_degree_map =
       getInDegreeMap(queue);
@@ -654,7 +654,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   //     |- node(grads)
   //     |- Prepare for next node
   // 3. Update queue
-  VLOG(6) << "Run Backward";
+  VLOG(3) << "Run Backward";
   while (!queue.empty()) {
     GradNodeBase* node = queue.front();
     VLOG(6) << "Running GradNode:" << node->name();
@@ -739,7 +739,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
       // Since we make edge has as same rank as bwd outputs, we indexing them
       // with the same rank(i, j)
       auto next_node_shared = edge.GetMutableGradNode();
+      VLOG(3) << "Found pending node: " << next_node_shared->name();
       // Next node could be nullptr if it is leaf tensor with no
       // AccumulationNode attached
      // Or it could also originated from dispensable inputs
@@ -826,7 +826,7 @@ void Backward(
     const std::vector<paddle::experimental::Tensor>& tensors,  // outputs
     const std::vector<paddle::experimental::Tensor>& grad_tensors,
     bool retain_graph) {
-  VLOG(6) << "Run in Backward";
+  VLOG(3) << "Run in Backward";
   paddle::platform::RecordEvent backward_record_event(
       "backward", paddle::platform::TracerEventType::Operator, 1);
   RunBackward(tensors, grad_tensors, retain_graph);
@@ -839,7 +839,7 @@ std::vector<paddle::experimental::Tensor> Grad(
     const std::vector<paddle::experimental::Tensor>& grad_tensors,
     bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
     const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
-  VLOG(6) << "Run in Grad";
+  VLOG(3) << "Run in Grad";
   DuplicateCheck(inputs, true /* is_input */);
   DuplicateCheck(tensors, false /* is_input */);
......
@@ -225,7 +225,7 @@ void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
       fwd_in_meta->SetGradNode(
           std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
     }
-    VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
+    VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
             << this->name() << " (addr: " << this << ") "
             << " to " << fwd_in_meta->GetMutableGradNode()->name()
             << " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
@@ -281,7 +281,7 @@ void GradNodeBase::SetGradOutMeta(
       fwd_in_meta->SetGradNode(
           std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
     }
-    VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
+    VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
             << this->name() << " (addr: " << this << ") "
             << " to " << fwd_in_meta->GetMutableGradNode()->name()
             << " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
......
@@ -68,6 +68,8 @@ void GradTensorHolder::CopyValueFromTensor(
         // Fill 1.0, use full to support complex, one_like don't support it.
         buffer_[slot_id][rank] =
             paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
+        egr::EagerUtils::autograd_meta(&(buffer_[slot_id][rank]))
+            ->SetStopGradient(false);
       }
     }
   }
@@ -75,8 +77,6 @@ void GradTensorHolder::CopyValueFromTensor(
 void GradTensorHolder::add(size_t slot_id, size_t rank,
                            const paddle::experimental::Tensor& t,
                            bool create_graph) {
-  // TODO(jiabin): We need to deal with empty input_buffer with slot size not
-  // empty;
   PADDLE_ENFORCE(slot_id < buffer_.size(),
                  paddle::platform::errors::Fatal(
                      "Invalid slot_id for GradTensorHolder::add() "
......
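The added SetStopGradient(false) makes the auto-filled initial grad (the tensor of 1.0s) differentiable, which is what a second backward pass needs. A minimal sketch of that kind of double-grad call, assuming the public paddle.grad API (the example is ours, not from the PR):

import paddle

x = paddle.randn([2, 3])
x.stop_gradient = False
y = (x * x).sum()

# First-order grad; create_graph=True keeps the graph so that the
# automatically created initial grad (filled with 1.0) can itself be
# differentiated by the second call.
(dx,) = paddle.grad([y], [x], create_graph=True)
(ddx,) = paddle.grad([dx.sum()], [x])
print(ddx.shape)  # [2, 3]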
@@ -1085,7 +1085,7 @@ void PartialGradEngine::Clear() {
 void PartialGradEngine::Execute() {
   PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
                                      "PartialGradEngine has been destructed"));
-  VLOG(10) << "Starts to execute PartialGradEngine";
+  VLOG(3) << "Starts to execute PartialGradEngine";
   results_ = task_->Run();
   Clear();
 }
......
@@ -442,8 +442,14 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
   // (5) dx = dout * ddy
   if (ddout) {
     auto& place = *dev_ctx.eigen_device();
-    // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace
-    if (ddout->numel() > ddx.get_ptr()->numel()) {
+    // size(ddout) > size(ddx) or we don't have ddx, ddout can't use memory of
+    // ddx using inplace
+    bool without_ddx = (ddx.get_ptr() == nullptr);
+    if (!without_ddx) {
+      without_ddx = (ddout->numel() > ddx.get_ptr()->numel());
+    }
+    if (without_ddx) {
       phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
           dev_ctx,
           ddx_safe,
......
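The new guard treats a missing optional ddx the same as a size mismatch, so ddout only reuses ddx's buffer when ddx exists and is at least as large. A small Python sketch of that decision logic (function and parameter names are ours, for illustration only):

def ddout_needs_own_buffer(ddout_numel, ddx_numel=None):
    """Mirror of the patched check: True when ddout must not alias ddx."""
    without_ddx = ddx_numel is None            # optional ddx was not provided
    if not without_ddx:
        without_ddx = ddout_numel > ddx_numel  # ddx exists but is too small
    return without_ddx

print(ddout_needs_own_buffer(12))                 # True  (no ddx at all)
print(ddout_needs_own_buffer(12, ddx_numel=6))    # True  (ddx too small)
print(ddout_needs_own_buffer(12, ddx_numel=12))   # False (safe to reuse ddx)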