未验证 提交 a729cb9c 编写于 作者: R Ruibiao Chen 提交者: GitHub

Fix bug of HandleComplexGradToRealGrad (#51302)

上级 e36cac06
...@@ -703,6 +703,7 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -703,6 +703,7 @@ void BuildOpFuncList(const platform::Place& place,
// op is not a operatorwithkernel, so direcly run OperatorBase::Run() // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
HandleOperatorBase( HandleOperatorBase(
place, var_scope, ops[i], &op_func_node, local_scope); place, var_scope, ops[i], &op_func_node, local_scope);
vec_func_list->emplace_back(op_func_node);
} else { } else {
VLOG(4) << "OP is not null"; VLOG(4) << "OP is not null";
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>( auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
...@@ -871,20 +872,14 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -871,20 +872,14 @@ void BuildOpFuncList(const platform::Place& place,
} }
} }
// post-process grad_op.outputs if need cast complex grad into real // for debug nan/inf
// grad. if (FLAGS_check_nan_inf) {
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. VLOG(4) << "Check nan/inf";
if (IsGradOp(op_type) && framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
output_name_map,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope,
static_build);
} }
vec_func_list->emplace_back(op_func_node);
if (!op_func_node.inplace_back_map.empty()) { if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map; auto& m = op_func_node.inplace_back_map;
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
...@@ -903,10 +898,19 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -903,10 +898,19 @@ void BuildOpFuncList(const platform::Place& place,
} }
} }
// for debug nan/inf // post-process grad_op.outputs if need cast complex grad into real
if (FLAGS_check_nan_inf) { // grad.
VLOG(4) << "Check nan/inf"; // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place); if (IsGradOp(op_type) &&
framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
output_name_map,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope,
static_build);
} }
} }
} catch (platform::EnforceNotMet& ex) { } catch (platform::EnforceNotMet& ex) {
...@@ -927,8 +931,6 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -927,8 +931,6 @@ void BuildOpFuncList(const platform::Place& place,
VLOG(4) << "End run " << place << " " VLOG(4) << "End run " << place << " "
<< op_func_node.operator_base_->DebugStringEx(local_scope); << op_func_node.operator_base_->DebugStringEx(local_scope);
vec_func_list->emplace_back(op_func_node);
// gc--------------------------------------------- // gc---------------------------------------------
auto iter = unused_var_map.find(op); auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) { if (iter == unused_var_map.end()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册