Unverified commit 9c3a35b9, authored by 姜永久, committed by GitHub

rm flags retain grad in pybind (#49888)

* rm flags_retain grad in pybind

* retain grads for xpu test

* set retain grad for xpu

* rm flag

* lint

---------
Co-authored-by: wanghuancoder <wanghuan29@baidu.com>
Parent: ac84dce9
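
This commit removes the global FLAGS_retain_grad_for_all_tensor switch and the egr::EagerUtils::CheckAndRetainGrad helpers it guarded, so the generated dygraph forward functions and the pybind entry points no longer retain gradients for every output. Tests that need the gradient of a non-leaf tensor now opt in per tensor via Tensor.retain_grads(), as the updated XPU zero-dim tests below do. A minimal sketch of the new test-side pattern (the op and device choice here are illustrative, not part of the diff):

import paddle

# Before this change, test files forced gradient retention globally:
#   paddle.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
# That flag (and CheckAndRetainGrad behind it) is removed here, so gradients
# on non-leaf tensors must be requested explicitly.

x = paddle.rand([])              # 0-D input, as in the zero-dim tests below
x.stop_gradient = False
x.retain_grads()                 # the updated tests retain the inputs they check

out = paddle.scale(x, scale=2.0, bias=1.0)
out.retain_grads()               # opt-in retention for the non-leaf output
out.backward()

assert x.grad is not None
assert out.grad is not None      # only populated because of retain_grads() above
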
@@ -99,7 +99,6 @@ paddle::experimental::Tensor add_n_ad_func(
 egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
 }
 grad_node->SetGradInMeta(out, 0);
-egr::EagerUtils::CheckAndRetainGrad(out);
 // Set TensorWrappers for Forward Outputs if needed
 }
...
@@ -162,7 +162,6 @@ paddle::experimental::Tensor conv2d_ad_func(
 egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
 }
 grad_node->SetGradInMeta(out, 0);
-egr::EagerUtils::CheckAndRetainGrad(out);
 // Set TensorWrappers for Forward Outputs if needed
 }
...
@@ -159,8 +159,6 @@ Conv2dGradNodeFinal::operator()(
 }
 grad_node->SetGradInMeta(grad_input, 0);
 grad_node->SetGradInMeta(grad_filter, 1);
-egr::EagerUtils::CheckAndRetainGrad(grad_input);
-egr::EagerUtils::CheckAndRetainGrad(grad_filter);
 // Set TensorWrappers for Forward Outputs if needed
 }
...
@@ -432,7 +432,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_QKVBiasOut,
 QKVBiasOut_accumulation_node);
 QKVBiasOut_accumulation_node->SetGradInMeta(QKVBiasOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(QKVBiasOut);
 grad_node->SetGradOutMeta(QKVBiasOut, 11);
 }
@@ -446,7 +445,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_SrcMaskOut,
 SrcMaskOut_accumulation_node);
 SrcMaskOut_accumulation_node->SetGradInMeta(SrcMaskOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(SrcMaskOut);
 grad_node->SetGradOutMeta(SrcMaskOut, 12);
 }
@@ -473,7 +471,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_LnOut,
 LnOut_accumulation_node);
 LnOut_accumulation_node->SetGradInMeta(LnOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(LnOut);
 grad_node->SetGradOutMeta(LnOut, 13);
 }
 if (LnMean.initialized()) {
@@ -505,7 +502,6 @@ fused_attention_dygraph_function(
 BiasDropoutResidualOut_accumulation_node);
 BiasDropoutResidualOut_accumulation_node->SetGradInMeta(
 BiasDropoutResidualOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(BiasDropoutResidualOut);
 grad_node->SetGradOutMeta(BiasDropoutResidualOut, 14);
 }
@@ -524,17 +520,14 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_CacheKVOut, 18);
 egr::EagerUtils::SetHistory(p_autograd_CacheKVOut, grad_node);
 grad_node->SetGradInMeta(CacheKVOut, 18);
-egr::EagerUtils::CheckAndRetainGrad(CacheKVOut);
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 19);
 egr::EagerUtils::SetHistory(p_autograd_Y, grad_node);
 grad_node->SetGradInMeta(Y, 19);
-egr::EagerUtils::CheckAndRetainGrad(Y);
 auto QKVOut_accumulation_node =
 std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKVOut);
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVOut, 0);
 egr::EagerUtils::SetHistory(p_autograd_QKVOut, QKVOut_accumulation_node);
 QKVOut_accumulation_node->SetGradInMeta(QKVOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(QKVOut);
 grad_node->SetGradOutMeta(QKVOut, 15);
 auto QKTVOut_accumulation_node =
@@ -543,7 +536,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_QKTVOut,
 QKTVOut_accumulation_node);
 QKTVOut_accumulation_node->SetGradInMeta(QKTVOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(QKTVOut);
 grad_node->SetGradOutMeta(QKTVOut, 16);
 auto TransposeOut2_accumulation_node =
@@ -552,7 +544,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_TransposeOut2,
 TransposeOut2_accumulation_node);
 TransposeOut2_accumulation_node->SetGradInMeta(TransposeOut2, 0);
-egr::EagerUtils::CheckAndRetainGrad(TransposeOut2);
 grad_node->SetGradOutMeta(TransposeOut2, 17);
 auto QKOut_accumulation_node =
@@ -560,7 +551,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKOut, 0);
 egr::EagerUtils::SetHistory(p_autograd_QKOut, QKOut_accumulation_node);
 QKOut_accumulation_node->SetGradInMeta(QKOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(QKOut);
 grad_node->SetGradOutMeta(QKOut, 18);
 auto SoftmaxOut_accumulation_node =
@@ -569,7 +559,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_SoftmaxOut,
 SoftmaxOut_accumulation_node);
 SoftmaxOut_accumulation_node->SetGradInMeta(SoftmaxOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut);
 grad_node->SetGradOutMeta(SoftmaxOut, 19);
 if (AttnDropoutOut.initialized()) {
@@ -580,7 +569,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
 AttnDropoutOut_accumulation_node);
 AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
 grad_node->SetGradOutMeta(AttnDropoutOut, 20);
 }
@@ -590,7 +578,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_FMHAOut,
 FMHAOut_accumulation_node);
 FMHAOut_accumulation_node->SetGradInMeta(FMHAOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(FMHAOut);
 grad_node->SetGradOutMeta(FMHAOut, 21);
 auto OutLinearOut_accumulation_node =
@@ -599,7 +586,6 @@ fused_attention_dygraph_function(
 egr::EagerUtils::SetHistory(p_autograd_OutLinearOut,
 OutLinearOut_accumulation_node);
 OutLinearOut_accumulation_node->SetGradInMeta(OutLinearOut, 0);
-egr::EagerUtils::CheckAndRetainGrad(OutLinearOut);
 grad_node->SetGradOutMeta(OutLinearOut, 22);
 }
 }
...
@@ -221,7 +221,6 @@ fused_bias_dropout_residual_layer_norm_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 4);
 egr::EagerUtils::SetHistory(p_autograd_Y, grad_node);
 grad_node->SetGradInMeta(Y, 4);
-egr::EagerUtils::CheckAndRetainGrad(Y);
 }
 }
...
@@ -363,7 +363,6 @@ fused_feedforward_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 0);
 egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
 grad_node->SetGradInMeta(Out, 0);
-egr::EagerUtils::CheckAndRetainGrad(Out);
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Dropout1Mask, 1);
 grad_node->SetGradInMeta(Dropout1Mask, 1);
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Dropout2Mask, 2);
...
@@ -372,7 +372,6 @@ fused_gate_attention_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 7);
 egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
 grad_node->SetGradInMeta(Out, 7);
-egr::EagerUtils::CheckAndRetainGrad(Out);
 }
 }
...
@@ -120,7 +120,6 @@ paddle::experimental::Tensor fused_gemm_epilogue_dygraph_function(
 egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 0);
 egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
 grad_node->SetGradInMeta(Out, 0);
-egr::EagerUtils::CheckAndRetainGrad(Out);
 }
 }
...
@@ -1305,15 +1305,6 @@ static std::string GenerateGradNodeCreationContent(
 paddle::string::Sprintf(SET_GRAD_IN_META_TEMPLATE,
 LegalizeVarName(inplace_input_name),
 output_position);
-// Intermediate Tensor does not require CheckAndRetainGrad
-if (!output.intermediate()) {
-VLOG(6) << "Generated Call RetainGradForTensor";
-const char* RETAIN_GRAD_TEMPLATE =
-" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
-grad_node_creation_str += paddle::string::Sprintf(
-RETAIN_GRAD_TEMPLATE, LegalizeVarName(inplace_input_name));
-}
 } else {
 const std::string& output_autograd_name =
 "p_autograd_" + LegalizeVarName(output_name);
@@ -1363,15 +1354,6 @@ static std::string GenerateGradNodeCreationContent(
 LegalizeVarName(output_name),
 output_position);
 }
-// Intermediate Tensor does not require CheckAndRetainGrad
-if (!output.intermediate()) {
-VLOG(6) << "Generated Call RetainGradForTensor";
-const char* RETAIN_GRAD_TEMPLATE =
-" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
-grad_node_creation_str += paddle::string::Sprintf(
-RETAIN_GRAD_TEMPLATE, LegalizeVarName(output_name));
-}
 }
 }
 VLOG(6) << "Generated SetGradIn/OutMeta";
...
@@ -280,8 +280,7 @@ FORWARD_BODY_TEMPLATE = """ if(require_any_grad) {{
 {}
 // SetGradOutMeta & SetEdges
 {}
-// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
+// SetOutRank & SetHistory & SetGradInMeta
-{}
 {}
 {}
 {}
@@ -300,8 +299,7 @@ HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """ if(trace_backward) {{
 {}
 // SetGradOutMeta & SetEdges
 {}
-// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
+// SetOutRank & SetHistory & SetGradInMeta
-{}
 {}
 {}
 {}
@@ -987,7 +985,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
 set_out_rank_list = []
 set_history_list = []
 set_grad_in_meta_list = []
-set_retain_grad_list = []
 num_outputs = len(forward_outputs_position_map.keys())
 for name, (_, pos) in forward_outputs_position_map.items():
 output_autograd_meta_name = GetAutoGradMetaName(name)
@@ -1002,19 +999,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
 set_grad_in_meta = (
 f"{indent}grad_node->SetGradInMeta({name}, {pos});"
 )
-set_retain_grad = (
-f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
-)
 set_out_rank_list.append(set_out_rank)
 set_history_list.append(set_history)
 set_grad_in_meta_list.append(set_grad_in_meta)
-set_retain_grad_list.append(set_retain_grad)
 set_out_rank_str = "\n".join(set_out_rank_list)
 set_history_str = "\n".join(set_history_list)
 set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
-set_retain_grad_str = "\n".join(set_retain_grad_list)
 node_event_name = forward_api_name + " node_creation"
 node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
@@ -1029,7 +1021,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
 set_out_rank_str,
 set_history_str,
 set_grad_in_meta_str,
-set_retain_grad_str,
 set_output_tensor_wrappers_str,
 )
 else:
@@ -1043,7 +1034,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
 set_out_rank_str,
 set_history_str,
 set_grad_in_meta_str,
-set_retain_grad_str,
 set_output_tensor_wrappers_str,
 )
 )
...
@@ -310,7 +310,6 @@ RunCustomOpNode::operator()(
 egr::EagerUtils::SetOutRankWithSlot(&(outs_auto_grad_metas[i]), i);
 egr::EagerUtils::SetHistory(&(outs_auto_grad_metas[i]), grad_node);
 grad_node->SetGradInMeta(out_tensors, i);
-egr::EagerUtils::CheckAndRetainGrad(out_tensors);
 }
 // Prepare Grad inputs with fwd outputs
...
@@ -122,6 +122,5 @@ inline void run_program_ad_func(
 // Set History for output set current Grad Node for
 egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
-egr::EagerUtils::CheckAndRetainGrad(deref_out);
 }
 }
...
@@ -27,10 +27,6 @@
 #include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/framework/variable.h"
-PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor,
-false,
-"retain grad for all tensor");
 namespace egr {
 /**
 * Implementation of Eager Utils.
@@ -409,35 +405,6 @@ std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
 }
 return ret;
 }
-// TODO(jiabin): remove all this when we fix all test using tmp grad
-void EagerUtils::CheckAndRetainGrad(
-const paddle::experimental::Tensor& tensor) {
-VLOG(6) << "Check RetainGradForTensor: " << tensor.name();
-if (FLAGS_retain_grad_for_all_tensor) {
-VLOG(6) << "RetainGradForTensor: " << tensor.name();
-egr::egr_utils_api::RetainGradForTensor(tensor);
-}
-}
-void EagerUtils::CheckAndRetainGrad(
-const std::vector<paddle::experimental::Tensor>& tensors) {
-if (FLAGS_retain_grad_for_all_tensor) {
-for (auto& tensor : tensors) {
-VLOG(6) << "RetainGradForTensor: " << tensor.name();
-egr::egr_utils_api::RetainGradForTensor(tensor);
-}
-}
-}
-void EagerUtils::CheckAndRetainGrad(
-const std::vector<paddle::experimental::Tensor*>& tensors) {
-if (FLAGS_retain_grad_for_all_tensor) {
-for (auto& tensor : tensors) {
-VLOG(6) << "RetainGradForTensor: " << tensor->name();
-egr::egr_utils_api::RetainGradForTensor(*tensor);
-}
-}
-}
 std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
 const paddle::experimental::Tensor& tensor) {
...
@@ -223,14 +223,6 @@ class EagerUtils {
 const std::vector<paddle::experimental::Tensor*>& out_var,
 std::vector<paddle::experimental::Tensor>* result);
-// end Intermidate needed.
-static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
-static void CheckAndRetainGrad(
-const std::vector<paddle::experimental::Tensor>& tensors);
-static void CheckAndRetainGrad(
-const std::vector<paddle::experimental::Tensor*>& tensors);
 static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
 const paddle::experimental::Tensor& tensor);
...
@@ -575,7 +575,6 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
 egr::EagerUtils::SetOutRankWithSlot(&(outs_auto_grad_metas[i]), i);
 egr::EagerUtils::SetHistory(&(outs_auto_grad_metas[i]), grad_node);
 grad_node->SetGradInMeta(out_tensors, i);
-egr::EagerUtils::CheckAndRetainGrad(out_tensors);
 }
 // Prepare Grad inputs with fwd outputs
...
@@ -432,12 +432,10 @@ PyObject* pylayer_method_apply(PyObject* cls,
 for (auto t : outputs_tensor[i]) {
 grad_node->SetGradInMeta(*t, i);
 }
-egr::EagerUtils::CheckAndRetainGrad(outputs_tensor[i]);
 } else {
 egr::EagerUtils::SetOutRankWithSlot(outputs_autograd_meta[i][0], i);
 egr::EagerUtils::SetHistory(outputs_autograd_meta[i][0], grad_node);
 grad_node->SetGradInMeta(*outputs_tensor[i][0], i);
-egr::EagerUtils::CheckAndRetainGrad(*outputs_tensor[i][0]);
 }
 }
 VLOG(6) << "PyLayer construct backward node finish...";
...
@@ -152,7 +152,6 @@ class TestFlipDoubleGradCheck(unittest.TestCase):
 gradient_checker.double_grad_check(
 [data], out, x_init=[data_arr], place=place, eps=eps
 )
-fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
 gradient_checker.double_grad_check_for_dygraph(
 self.flip_wrapper, [data], out, x_init=[data_arr], place=place
 )
@@ -184,7 +183,6 @@ class TestFlipTripleGradCheck(unittest.TestCase):
 gradient_checker.triple_grad_check(
 [data], out, x_init=[data_arr], place=place, eps=eps
 )
-fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
 gradient_checker.triple_grad_check_for_dygraph(
 self.flip_wrapper, [data], out, x_init=[data_arr], place=place
 )
...
@@ -20,7 +20,6 @@ import paddle
 import paddle.nn.functional as F
 paddle.set_device('xpu')
-paddle.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
 unary_api_list = [
 paddle.nn.functional.elu,
@@ -102,6 +101,7 @@ class TestUnaryAPI(unittest.TestCase):
 x = paddle.rand([])
 x.stop_gradient = False
 out = api(x)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.shape, [])
@@ -147,6 +147,7 @@ class TestReduceAPI(unittest.TestCase):
 x = paddle.rand([])
 x.stop_gradient = False
 out = api(x, None)
+out.retain_grads()
 out.backward()
@@ -201,12 +202,15 @@ class TestBinaryAPI(unittest.TestCase):
 y = paddle.rand([])
 x.stop_gradient = False
 y.stop_gradient = False
+x.retain_grads()
+y.retain_grads()
 if isinstance(api, dict):
 out = api['func'](x, y)
 out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
 else:
 out = api(x, y)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.shape, [])
@@ -228,6 +232,7 @@ class TestBinaryAPI(unittest.TestCase):
 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
 else:
 out = api(x, y)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.shape, [2, 3, 4])
@@ -243,12 +248,15 @@ class TestBinaryAPI(unittest.TestCase):
 y = paddle.rand([2, 3, 4])
 x.stop_gradient = False
 y.stop_gradient = False
+x.retain_grads()
+y.retain_grads()
 if isinstance(api, dict):
 out = api['func'](x, y)
 out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
 else:
 out = api(x, y)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.shape, [])
@@ -265,6 +273,7 @@ class TestBinaryAPI(unittest.TestCase):
 y = 0.5
 if isinstance(api, dict):
 out = getattr(paddle.Tensor, api['cls_method'])(x, y)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.shape, [])
@@ -381,7 +390,9 @@ class TestSundryAPI(unittest.TestCase):
 def test_pow_factor(self):
 x = paddle.rand([])
 x.stop_gradient = False
+x.retain_grads()
 out = paddle.pow(x, 2.0)
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -391,7 +402,9 @@ class TestSundryAPI(unittest.TestCase):
 def test_cast(self):
 x = paddle.full([], 1.0, 'float32')
 x.stop_gradient = False
+x.retain_grads()
 out = paddle.cast(x, 'int32')
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -401,7 +414,9 @@ class TestSundryAPI(unittest.TestCase):
 def test_clip(self):
 x = paddle.uniform([], None, -10, 10)
 x.stop_gradient = False
+x.retain_grads()
 out = paddle.clip(x, -5, 5)
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -446,6 +461,7 @@ class TestSundryAPI(unittest.TestCase):
 x = paddle.rand([])
 x.stop_gradient = False
 out = paddle.transpose(x, [])
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -461,6 +477,7 @@ class TestSundryAPI(unittest.TestCase):
 x = paddle.rand([])
 x.stop_gradient = False
 out = paddle.moveaxis(x, [], [])
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -476,6 +493,7 @@ class TestSundryAPI(unittest.TestCase):
 x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
 index = paddle.full([], 2, 'int64')
 out = paddle.gather(x, index)
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -489,6 +507,7 @@ class TestSundryAPI(unittest.TestCase):
 )
 index = paddle.full([], 1, 'int64')
 out = paddle.gather(x, index)
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [3])
@@ -541,10 +560,18 @@ class TestSundryAPI(unittest.TestCase):
 x2.stop_gradient = False
 x3.stop_gradient = False
+x1.retain_grads()
+x2.retain_grads()
+x3.retain_grads()
 out1 = paddle.diagflat(x1, 1)
 out2 = paddle.diagflat(x2, -1)
 out3 = paddle.diagflat(x3, 0)
+out1.retain_grads()
+out2.retain_grads()
+out3.retain_grads()
 out1.backward()
 out2.backward()
 out3.backward()
@@ -592,7 +619,9 @@ class TestSundryAPI(unittest.TestCase):
 def test_scale(self):
 x = paddle.rand([])
 x.stop_gradient = False
+x.retain_grads()
 out = paddle.scale(x, scale=2.0, bias=1.0)
+out.retain_grads()
 out.backward()
 self.assertEqual(out.shape, [])
@@ -674,24 +703,28 @@ class TestSundryAPI(unittest.TestCase):
 x.stop_gradient = False
 out = paddle.reshape(x, [])
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [])
 self.assertEqual(out.shape, [])
 self.assertEqual(out.grad.shape, [])
 out = paddle.reshape(x, [1])
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [])
 self.assertEqual(out.shape, [1])
 self.assertEqual(out.grad.shape, [1])
 out = paddle.reshape(x, [-1])
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [])
 self.assertEqual(out.shape, [1])
 self.assertEqual(out.grad.shape, [1])
 out = paddle.reshape(x, [-1, 1])
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [])
 self.assertEqual(out.shape, [1, 1])
@@ -702,6 +735,7 @@ class TestSundryAPI(unittest.TestCase):
 x.stop_gradient = False
 out = paddle.reshape(x, [])
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [1, 1])
 self.assertEqual(out.shape, [])
@@ -709,6 +743,7 @@ class TestSundryAPI(unittest.TestCase):
 new_shape = paddle.to_tensor([1, 1, 1], "int32")
 out = paddle.reshape(x, new_shape)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [1, 1])
 self.assertEqual(out.shape, [1, 1, 1])
@@ -716,6 +751,7 @@ class TestSundryAPI(unittest.TestCase):
 new_shape = paddle.to_tensor([-1], "int32")
 out = paddle.reshape(x, new_shape)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [1, 1])
 self.assertEqual(out.shape, [1])
@@ -723,6 +759,7 @@ class TestSundryAPI(unittest.TestCase):
 new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
 out = paddle.reshape(x, new_shape)
+out.retain_grads()
 out.backward()
 self.assertEqual(x.grad.shape, [1, 1])
 self.assertEqual(out.shape, [1, 1])
@@ -765,9 +802,15 @@ class TestSundryAPI(unittest.TestCase):
 x1.stop_gradient = False
 x2.stop_gradient = False
+x1.retain_grads()
+x2.retain_grads()
 out1 = paddle.sort(x1, axis=-1)
 out2 = paddle.sort(x2, axis=0)
+out1.retain_grads()
+out2.retain_grads()
 out1.backward()
 out2.backward()
@@ -787,10 +830,15 @@ class TestSundryAPI(unittest.TestCase):
 x2 = paddle.rand([])
 x1.stop_gradient = False
 x2.stop_gradient = False
+x1.retain_grads()
+x2.retain_grads()
 out1 = paddle.argsort(x1, axis=-1)
 out2 = paddle.argsort(x2, axis=0)
+out1.retain_grads()
+out2.retain_grads()
 out1.backward()
 out2.backward()
...
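
For reference, the zero-dim checks above all follow the same shape after this change: any tensor whose .grad is asserted gets an explicit retain_grads() call before backward(). A condensed sketch mirroring the reshape changes (shapes taken from the assertions in the diff; test scaffolding and device setup omitted):

import paddle

x = paddle.rand([])
x.stop_gradient = False
x.retain_grads()

out = paddle.reshape(x, [1])
out.retain_grads()          # without this, the non-leaf out would not keep its grad
out.backward()

assert x.grad.shape == []        # gradient of a 0-D leaf is 0-D
assert out.shape == [1]
assert out.grad.shape == [1]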