未验证 提交 c62a7e25 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Fix edvr starganv2 (#43471)

* fix starganv2

* fix starganv2 stop_gradient end error

* fix edvr_starganv2

* fix mul kernel to fix optional ddx

* fix typo
上级 8cec1271
......@@ -1152,7 +1152,8 @@ static std::string GenerateGradNodeCreationContent(
size_t bwd_in_slot_num = out_vars.size();
size_t bwd_out_slot_num = in_vars.size();
const char* GRAD_OP_NODE_TEMPLATE =
" auto grad_node = std::shared_ptr<GradNode%s>(new GradNode%s(%d, "
" auto grad_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(%d, "
"%d));\n";
grad_node_creation_str += " // Create GradOpNode\n";
grad_node_creation_str +=
......@@ -2080,10 +2081,8 @@ static std::string GenerateSingleOpBase(
generated_grad_function_body +=
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> " +
hooked_grads +
" = "
"GradNode" +
fwd_op_type + "::ApplyGradientHooks(grads);\n";
hooked_grads + " = " + fwd_op_type +
"GradNodeCompat::ApplyGradientHooks(grads);\n";
// [Generation] Get Ins Map
std::unordered_set<std::string> dispensable_input_name_set;
......@@ -2547,7 +2546,7 @@ static std::string GenerateGradNodeCCContents(
*/
const char* EAGER_LOG_TEMPLATE =
" VLOG(3) << \"Running Eager Backward Node: GradNode%s\";\n";
" VLOG(3) << \"Running Eager Backward Node: %sGradNodeCompat\";\n";
std::string generated_grad_function_body =
paddle::string::Sprintf(EAGER_LOG_TEMPLATE, fwd_op_type);
......@@ -2616,7 +2615,7 @@ static std::string GenerateGradNodeCCContents(
const char* GRAD_FUNCTION_TEMPLATE =
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> "
"GradNode%s::operator()("
"%sGradNodeCompat::operator()("
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize>& grads, bool "
"create_graph, bool is_new_grad) {\n"
......@@ -2645,14 +2644,15 @@ static std::string GenerateGradNodeHeaderContents(
VLOG(6) << "Generating Grad Node Header";
const char* GRAD_NODE_TEMPLATE =
"class GradNode%s : public egr::GradNodeBase {\n"
"class %sGradNodeCompat : public egr::GradNodeBase {\n"
" public:\n"
" GradNode%s() : egr::GradNodeBase() { VLOG(7) << \" Construct "
"GradNode%s \"; }\n"
" GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
" %sGradNodeCompat() : egr::GradNodeBase() { VLOG(7) << \" Construct "
"%sGradNodeCompat \"; }\n"
" %sGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
"egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { VLOG(7) << \" "
"Construct GradNode%s \"; }\n"
" ~GradNode%s() override { VLOG(6) << \" Destruct GradNode%s \"; }\n"
"Construct %sGradNodeCompat \"; }\n"
" ~%sGradNodeCompat() override { VLOG(6) << \" Destruct "
"%sGradNodeCompat \"; }\n"
"\n"
" virtual "
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
......@@ -2667,11 +2667,11 @@ static std::string GenerateGradNodeHeaderContents(
"%s\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \"GradNode%sMid\"; } \n "
" std::string name() override { return \"%sGradNodeCompat\"; } \n "
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new "
"GradNode%s(*this));\n "
" auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(*this));\n "
" return copied_node;\n "
"}}\n "
"\n"
......
......@@ -147,7 +147,18 @@ def RemoveConstAndReference(string):
def GetGradNodeName(string):
return f"GradNode{string}Final"
def str2Hump(text):
arr = filter(None, text.split('_'))
res = ''
for i in arr:
res = res + i[0].upper() + i[1:]
return res
string = str2Hump(string)
if string.rfind("Grad") == (len(string) - 4):
string = string[:-4]
return f"{string}GradNodeFinal"
def GetDygraphForwardFunctionName(string):
......@@ -335,6 +346,7 @@ def ParseYamlInplaceInfo(string):
### Generator Base ###
########################
class FunctionGeneratorBase:
def __init__(self, forward_api_contents, namespace):
self.forward_api_contents = forward_api_contents
self.namespace = namespace
......@@ -423,8 +435,9 @@ class FunctionGeneratorBase:
input_type = forward_input[1]
input_pos = forward_input[2]
self.forward_inputs_position_map[
input_name] = [input_type, input_pos]
self.forward_inputs_position_map[input_name] = [
input_type, input_pos
]
for i in range(len(forward_returns_list)):
forward_return = forward_returns_list[i]
......@@ -432,11 +445,13 @@ class FunctionGeneratorBase:
return_type = forward_return[1]
return_pos = forward_return[2]
self.forward_outputs_position_map[
return_name] = [return_type, return_pos]
self.forward_outputs_position_map[return_name] = [
return_type, return_pos
]
class GeneratorBase:
def __init__(self, api_yaml_path):
self.namespace = ""
self.api_yaml_path = api_yaml_path
......
......@@ -411,6 +411,7 @@ def GenerateCoreOpInfoDefinition():
## Generator Class ##
#####################
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
self.forward_api_contents = forward_api_contents
# Members from Parent:
......@@ -532,8 +533,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
max_input_position = max(max_input_position, pos)
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
assert pos > max_input_position, AssertMessage(
pos, max_input_position)
def BackwardValidationCheck(self):
backward_forward_inputs_map = self.backward_forward_inputs_map
......@@ -678,7 +679,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
grad_node_name = GetGradNodeName(self.backward_api_name)
# Helper
indent = GetIndent(2)
......@@ -845,6 +846,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
......@@ -947,12 +949,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if is_inplaced and len(forward_outputs_position_map) == 1:
api_out_type = "auto&"
forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
num_outputs = len(
forward_outputs_position_map.keys()) - len(intermediate_outputs)
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(function_name,
"api_result")
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
function_name, "api_result")
# Get Outputs
get_outputs_str = ""
......@@ -1007,8 +1009,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if pos == corresponding_pos:
has_corresponding_grad_output = True
if has_corresponding_grad_output or (
name in forward_inplace_map and
forward_api_name not in inplace_check_blacklist):
name in forward_inplace_map
and forward_api_name not in inplace_check_blacklist):
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
......@@ -1116,17 +1118,20 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
num_args = len(forward_inputs_position_map.keys()) + len(
forward_attrs_list)
num_args = len(
forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())
final_state_fwd_api_name = "final_state_" + forward_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_returns_info[final_state_fwd_api_name] = [
"" for i in range(num_returns)
]
core_ops_args_info[final_state_fwd_api_name] = [
"" for i in range(num_args)
]
core_ops_args_type_info[final_state_fwd_api_name] = [
"" for i in range(num_args)
]
for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
......@@ -1159,6 +1164,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
def __init__(self,
forward_api_contents,
grad_api_contents,
......@@ -1167,7 +1173,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Record name mapping from forward_api_name to grad_api_names
# Record name mapping from forward_var_name to grad_var_names
self.to_next_grad_name_mapping = {} # {name : name}
# Generated Results
......@@ -1281,7 +1287,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name)
grad_node_name = GetGradNodeName(forward_op_name)
grad_node_name = GetGradNodeName(self.backward_api_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str, grad_node_name,
......@@ -1447,8 +1453,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(backward_api_name,
"returns")
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
backward_api_name, "returns")
# Prepare for Node Creation if Necessary
inputs_autograd_meta_str = ""
......@@ -1533,7 +1539,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"{indent}return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
grad_node_name = GetGradNodeName(self.backward_api_name)
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
......@@ -1560,6 +1566,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
class DygraphForwardAndNodesGenerator(GeneratorBase):
def __init__(self, api_yaml_path, backward_yaml_path):
# Parent members:
# self.namespace
......@@ -1617,8 +1624,9 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)
node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace,
node_generator = DygraphNodeGenerator(forward_api_contents,
backward_api_contents,
namespace,
next_grad_api_contents)
node_generator.run()
self.node_declaration_str += node_generator.node_declaration_str + "\n"
......
......@@ -536,7 +536,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
const std::vector<paddle::experimental::Tensor>& inputs = {},
bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
VLOG(6) << "Start Backward";
VLOG(3) << "Start Backward";
// *Gradient Hook should happen at node-level
// *Inplace version check should perform at node-level
......@@ -634,7 +634,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
}
VLOG(6) << "Update In degree Map for backward";
VLOG(3) << "Update In degree Map for backward";
// 3. Compute in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map =
getInDegreeMap(queue);
......@@ -654,7 +654,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// |- node(grads)
// |- Prepare for next node
// 3. Update queue
VLOG(6) << "Run Backward";
VLOG(3) << "Run Backward";
while (!queue.empty()) {
GradNodeBase* node = queue.front();
VLOG(6) << "Running GradNode:" << node->name();
......@@ -739,7 +739,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Since we make edge has as same rank as bwd outputs, we indexing them
// with the same rank(i, j)
auto next_node_shared = edge.GetMutableGradNode();
VLOG(3) << "Found pending node: " << next_node_shared->name();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
......@@ -826,7 +826,7 @@ void Backward(
const std::vector<paddle::experimental::Tensor>& tensors, // outputs
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph) {
VLOG(6) << "Run in Backward";
VLOG(3) << "Run in Backward";
paddle::platform::RecordEvent backward_record_event(
"backward", paddle::platform::TracerEventType::Operator, 1);
RunBackward(tensors, grad_tensors, retain_graph);
......@@ -839,7 +839,7 @@ std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
VLOG(6) << "Run in Grad";
VLOG(3) << "Run in Grad";
DuplicateCheck(inputs, true /* is_input */);
DuplicateCheck(tensors, false /* is_input */);
......
......@@ -225,7 +225,7 @@ void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
......@@ -281,7 +281,7 @@ void GradNodeBase::SetGradOutMeta(
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
......
......@@ -68,6 +68,8 @@ void GradTensorHolder::CopyValueFromTensor(
// Fill 1.0, use full to support complex, one_like don't support it.
buffer_[slot_id][rank] =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
egr::EagerUtils::autograd_meta(&(buffer_[slot_id][rank]))
->SetStopGradient(false);
}
}
}
......@@ -75,8 +77,6 @@ void GradTensorHolder::CopyValueFromTensor(
void GradTensorHolder::add(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t,
bool create_graph) {
// TODO(jiabin): We need to deal with empty input_buffer with slot size not
// empty;
PADDLE_ENFORCE(slot_id < buffer_.size(),
paddle::platform::errors::Fatal(
"Invalid slot_id for GradTensorHolder::add() "
......
......@@ -1085,7 +1085,7 @@ void PartialGradEngine::Clear() {
void PartialGradEngine::Execute() {
PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
"PartialGradEngine has been destructed"));
VLOG(10) << "Starts to execute PartialGradEngine";
VLOG(3) << "Starts to execute PartialGradEngine";
results_ = task_->Run();
Clear();
}
......
......@@ -442,8 +442,14 @@ void MultiplyDoubleGradKernel(const Context& dev_ctx,
// (5) dx = dout * ddy
if (ddout) {
auto& place = *dev_ctx.eigen_device();
// size(ddout) > size(ddx), ddout can't use memory of ddx using inplace
if (ddout->numel() > ddx.get_ptr()->numel()) {
// size(ddout) > size(ddx) or we don't have ddx, ddout can't use memory of
// ddx using inplace
bool without_ddx = (ddx.get_ptr() == nullptr);
if (!without_ddx) {
without_ddx = (ddout->numel() > ddx.get_ptr()->numel());
}
if (without_ddx) {
phi::funcs::ElemwiseGradCompute<Context, T, MulGradDX<T>, MulGradDY<T>>(
dev_ctx,
ddx_safe,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册