未验证 提交 a729cb9c 编写于 作者: R Ruibiao Chen 提交者: GitHub

Fix bug of HandleComplexGradToRealGrad (#51302)

上级 e36cac06
......@@ -703,6 +703,7 @@ void BuildOpFuncList(const platform::Place& place,
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
HandleOperatorBase(
place, var_scope, ops[i], &op_func_node, local_scope);
vec_func_list->emplace_back(op_func_node);
} else {
VLOG(4) << "OP is not null";
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
......@@ -871,20 +872,14 @@ void BuildOpFuncList(const platform::Place& place,
}
}
// post-process grad_op.outputs if need cast complex grad into real
// grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (IsGradOp(op_type) &&
framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
output_name_map,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope,
static_build);
// for debug nan/inf
if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
}
vec_func_list->emplace_back(op_func_node);
if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map;
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
......@@ -903,10 +898,19 @@ void BuildOpFuncList(const platform::Place& place,
}
}
// for debug nan/inf
if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
// post-process grad_op.outputs if need cast complex grad into real
// grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (IsGradOp(op_type) &&
framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
output_name_map,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope,
static_build);
}
}
} catch (platform::EnforceNotMet& ex) {
......@@ -927,8 +931,6 @@ void BuildOpFuncList(const platform::Place& place,
VLOG(4) << "End run " << place << " "
<< op_func_node.operator_base_->DebugStringEx(local_scope);
vec_func_list->emplace_back(op_func_node);
// gc---------------------------------------------
auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册