Unverified commit 9c3a35b9, authored by 姜永久, committed by GitHub

rm flags retain grad in pybind (#49888)

* rm flags_retain grad in pybind

* retain grads for xpu test

* set retain grad for xpu

* rm flag

* lint

---------
Co-authored-by: wanghuancoder <wanghuan29@baidu.com>
Parent commit: ac84dce9
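
The practical effect of this change is visible in the test updates below: the global `FLAGS_retain_grad_for_all_tensor` flag and the `egr::EagerUtils::CheckAndRetainGrad` helper are removed, so gradients of non-leaf tensors are kept only when the caller opts in explicitly via `Tensor.retain_grads()`. The following is a minimal sketch of that pattern, assuming an eager-mode (dygraph) Paddle build; the 0-d tensor and the use of `paddle.scale` are illustrative only, mirroring the updated zero-dim tests in this diff.

```python
import paddle

# Previously, tests relied on the now-removed global flag:
#   paddle.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
# After this change, a non-leaf tensor must opt in explicitly.
x = paddle.rand([])          # 0-d leaf tensor
x.stop_gradient = False

out = paddle.scale(x, scale=2.0, bias=1.0)   # non-leaf result
out.retain_grads()           # keep out.grad after backward()
out.backward()

assert x.grad is not None    # leaf gradients are always populated
assert out.grad is not None  # populated only because retain_grads() was called
```

This is the same opt-in pattern the updated XPU and zero-dim tests follow in place of setting the global flag.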
......@@ -99,7 +99,6 @@ paddle::experimental::Tensor add_n_ad_func(
egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
}
grad_node->SetGradInMeta(out, 0);
egr::EagerUtils::CheckAndRetainGrad(out);
// Set TensorWrappers for Forward Outputs if needed
}
......
......@@ -162,7 +162,6 @@ paddle::experimental::Tensor conv2d_ad_func(
egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
}
grad_node->SetGradInMeta(out, 0);
egr::EagerUtils::CheckAndRetainGrad(out);
// Set TensorWrappers for Forward Outputs if needed
}
......
......@@ -159,8 +159,6 @@ Conv2dGradNodeFinal::operator()(
}
grad_node->SetGradInMeta(grad_input, 0);
grad_node->SetGradInMeta(grad_filter, 1);
egr::EagerUtils::CheckAndRetainGrad(grad_input);
egr::EagerUtils::CheckAndRetainGrad(grad_filter);
// Set TensorWrappers for Forward Outputs if needed
}
......
......@@ -432,7 +432,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_QKVBiasOut,
QKVBiasOut_accumulation_node);
QKVBiasOut_accumulation_node->SetGradInMeta(QKVBiasOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKVBiasOut);
grad_node->SetGradOutMeta(QKVBiasOut, 11);
}
......@@ -446,7 +445,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_SrcMaskOut,
SrcMaskOut_accumulation_node);
SrcMaskOut_accumulation_node->SetGradInMeta(SrcMaskOut, 0);
egr::EagerUtils::CheckAndRetainGrad(SrcMaskOut);
grad_node->SetGradOutMeta(SrcMaskOut, 12);
}
......@@ -473,7 +471,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_LnOut,
LnOut_accumulation_node);
LnOut_accumulation_node->SetGradInMeta(LnOut, 0);
egr::EagerUtils::CheckAndRetainGrad(LnOut);
grad_node->SetGradOutMeta(LnOut, 13);
}
if (LnMean.initialized()) {
......@@ -505,7 +502,6 @@ fused_attention_dygraph_function(
BiasDropoutResidualOut_accumulation_node);
BiasDropoutResidualOut_accumulation_node->SetGradInMeta(
BiasDropoutResidualOut, 0);
egr::EagerUtils::CheckAndRetainGrad(BiasDropoutResidualOut);
grad_node->SetGradOutMeta(BiasDropoutResidualOut, 14);
}
......@@ -524,17 +520,14 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_CacheKVOut, 18);
egr::EagerUtils::SetHistory(p_autograd_CacheKVOut, grad_node);
grad_node->SetGradInMeta(CacheKVOut, 18);
egr::EagerUtils::CheckAndRetainGrad(CacheKVOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 19);
egr::EagerUtils::SetHistory(p_autograd_Y, grad_node);
grad_node->SetGradInMeta(Y, 19);
egr::EagerUtils::CheckAndRetainGrad(Y);
auto QKVOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKVOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKVOut, QKVOut_accumulation_node);
QKVOut_accumulation_node->SetGradInMeta(QKVOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKVOut);
grad_node->SetGradOutMeta(QKVOut, 15);
auto QKTVOut_accumulation_node =
......@@ -543,7 +536,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_QKTVOut,
QKTVOut_accumulation_node);
QKTVOut_accumulation_node->SetGradInMeta(QKTVOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKTVOut);
grad_node->SetGradOutMeta(QKTVOut, 16);
auto TransposeOut2_accumulation_node =
......@@ -552,7 +544,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_TransposeOut2,
TransposeOut2_accumulation_node);
TransposeOut2_accumulation_node->SetGradInMeta(TransposeOut2, 0);
egr::EagerUtils::CheckAndRetainGrad(TransposeOut2);
grad_node->SetGradOutMeta(TransposeOut2, 17);
auto QKOut_accumulation_node =
......@@ -560,7 +551,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKOut, QKOut_accumulation_node);
QKOut_accumulation_node->SetGradInMeta(QKOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKOut);
grad_node->SetGradOutMeta(QKOut, 18);
auto SoftmaxOut_accumulation_node =
......@@ -569,7 +559,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_SoftmaxOut,
SoftmaxOut_accumulation_node);
SoftmaxOut_accumulation_node->SetGradInMeta(SoftmaxOut, 0);
egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut);
grad_node->SetGradOutMeta(SoftmaxOut, 19);
if (AttnDropoutOut.initialized()) {
......@@ -580,7 +569,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
AttnDropoutOut_accumulation_node);
AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
grad_node->SetGradOutMeta(AttnDropoutOut, 20);
}
......@@ -590,7 +578,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_FMHAOut,
FMHAOut_accumulation_node);
FMHAOut_accumulation_node->SetGradInMeta(FMHAOut, 0);
egr::EagerUtils::CheckAndRetainGrad(FMHAOut);
grad_node->SetGradOutMeta(FMHAOut, 21);
auto OutLinearOut_accumulation_node =
......@@ -599,7 +586,6 @@ fused_attention_dygraph_function(
egr::EagerUtils::SetHistory(p_autograd_OutLinearOut,
OutLinearOut_accumulation_node);
OutLinearOut_accumulation_node->SetGradInMeta(OutLinearOut, 0);
egr::EagerUtils::CheckAndRetainGrad(OutLinearOut);
grad_node->SetGradOutMeta(OutLinearOut, 22);
}
}
......
......@@ -221,7 +221,6 @@ fused_bias_dropout_residual_layer_norm_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 4);
egr::EagerUtils::SetHistory(p_autograd_Y, grad_node);
grad_node->SetGradInMeta(Y, 4);
egr::EagerUtils::CheckAndRetainGrad(Y);
}
}
......
......@@ -363,7 +363,6 @@ fused_feedforward_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 0);
egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
grad_node->SetGradInMeta(Out, 0);
egr::EagerUtils::CheckAndRetainGrad(Out);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Dropout1Mask, 1);
grad_node->SetGradInMeta(Dropout1Mask, 1);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Dropout2Mask, 2);
......
......@@ -372,7 +372,6 @@ fused_gate_attention_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 7);
egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
grad_node->SetGradInMeta(Out, 7);
egr::EagerUtils::CheckAndRetainGrad(Out);
}
}
......
......@@ -120,7 +120,6 @@ paddle::experimental::Tensor fused_gemm_epilogue_dygraph_function(
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 0);
egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
grad_node->SetGradInMeta(Out, 0);
egr::EagerUtils::CheckAndRetainGrad(Out);
}
}
......
......@@ -1305,15 +1305,6 @@ static std::string GenerateGradNodeCreationContent(
paddle::string::Sprintf(SET_GRAD_IN_META_TEMPLATE,
LegalizeVarName(inplace_input_name),
output_position);
// Intermediate Tensor does not require CheckAndRetainGrad
if (!output.intermediate()) {
VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str += paddle::string::Sprintf(
RETAIN_GRAD_TEMPLATE, LegalizeVarName(inplace_input_name));
}
} else {
const std::string& output_autograd_name =
"p_autograd_" + LegalizeVarName(output_name);
......@@ -1363,15 +1354,6 @@ static std::string GenerateGradNodeCreationContent(
LegalizeVarName(output_name),
output_position);
}
// Intermediate Tensor does not require CheckAndRetainGrad
if (!output.intermediate()) {
VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE =
" egr::EagerUtils::CheckAndRetainGrad(%s);\n";
grad_node_creation_str += paddle::string::Sprintf(
RETAIN_GRAD_TEMPLATE, LegalizeVarName(output_name));
}
}
}
VLOG(6) << "Generated SetGradIn/OutMeta";
......
......@@ -280,8 +280,7 @@ FORWARD_BODY_TEMPLATE = """ if(require_any_grad) {{
{}
// SetGradOutMeta & SetEdges
{}
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
// SetOutRank & SetHistory & SetGradInMeta
{}
{}
{}
......@@ -300,8 +299,7 @@ HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """ if(trace_backward) {{
{}
// SetGradOutMeta & SetEdges
{}
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
// SetOutRank & SetHistory & SetGradInMeta
{}
{}
{}
......@@ -987,7 +985,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_out_rank_list = []
set_history_list = []
set_grad_in_meta_list = []
set_retain_grad_list = []
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
......@@ -1002,19 +999,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_grad_in_meta = (
f"{indent}grad_node->SetGradInMeta({name}, {pos});"
)
set_retain_grad = (
f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
)
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
......@@ -1029,7 +1021,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_out_rank_str,
set_history_str,
set_grad_in_meta_str,
set_retain_grad_str,
set_output_tensor_wrappers_str,
)
else:
......@@ -1043,7 +1034,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_out_rank_str,
set_history_str,
set_grad_in_meta_str,
set_retain_grad_str,
set_output_tensor_wrappers_str,
)
)
......
......@@ -310,7 +310,6 @@ RunCustomOpNode::operator()(
egr::EagerUtils::SetOutRankWithSlot(&(outs_auto_grad_metas[i]), i);
egr::EagerUtils::SetHistory(&(outs_auto_grad_metas[i]), grad_node);
grad_node->SetGradInMeta(out_tensors, i);
egr::EagerUtils::CheckAndRetainGrad(out_tensors);
}
// Prepare Grad inputs with fwd outputs
......
......@@ -122,6 +122,5 @@ inline void run_program_ad_func(
// Set History for output set current Grad Node for
egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
egr::EagerUtils::CheckAndRetainGrad(deref_out);
}
}
......@@ -27,10 +27,6 @@
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/variable.h"
PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor,
false,
"retain grad for all tensor");
namespace egr {
/**
* Implementation of Eager Utils.
......@@ -409,35 +405,6 @@ std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
}
return ret;
}
// TODO(jiabin): remove all this when we fix all test using tmp grad
void EagerUtils::CheckAndRetainGrad(
const paddle::experimental::Tensor& tensor) {
VLOG(6) << "Check RetainGradForTensor: " << tensor.name();
if (FLAGS_retain_grad_for_all_tensor) {
VLOG(6) << "RetainGradForTensor: " << tensor.name();
egr::egr_utils_api::RetainGradForTensor(tensor);
}
}
void EagerUtils::CheckAndRetainGrad(
const std::vector<paddle::experimental::Tensor>& tensors) {
if (FLAGS_retain_grad_for_all_tensor) {
for (auto& tensor : tensors) {
VLOG(6) << "RetainGradForTensor: " << tensor.name();
egr::egr_utils_api::RetainGradForTensor(tensor);
}
}
}
void EagerUtils::CheckAndRetainGrad(
const std::vector<paddle::experimental::Tensor*>& tensors) {
if (FLAGS_retain_grad_for_all_tensor) {
for (auto& tensor : tensors) {
VLOG(6) << "RetainGradForTensor: " << tensor->name();
egr::egr_utils_api::RetainGradForTensor(*tensor);
}
}
}
std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
const paddle::experimental::Tensor& tensor) {
......
......@@ -223,14 +223,6 @@ class EagerUtils {
const std::vector<paddle::experimental::Tensor*>& out_var,
std::vector<paddle::experimental::Tensor>* result);
// end Intermidate needed.
static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
static void CheckAndRetainGrad(
const std::vector<paddle::experimental::Tensor>& tensors);
static void CheckAndRetainGrad(
const std::vector<paddle::experimental::Tensor*>& tensors);
static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
const paddle::experimental::Tensor& tensor);
......
......@@ -575,7 +575,6 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
egr::EagerUtils::SetOutRankWithSlot(&(outs_auto_grad_metas[i]), i);
egr::EagerUtils::SetHistory(&(outs_auto_grad_metas[i]), grad_node);
grad_node->SetGradInMeta(out_tensors, i);
egr::EagerUtils::CheckAndRetainGrad(out_tensors);
}
// Prepare Grad inputs with fwd outputs
......
......@@ -432,12 +432,10 @@ PyObject* pylayer_method_apply(PyObject* cls,
for (auto t : outputs_tensor[i]) {
grad_node->SetGradInMeta(*t, i);
}
egr::EagerUtils::CheckAndRetainGrad(outputs_tensor[i]);
} else {
egr::EagerUtils::SetOutRankWithSlot(outputs_autograd_meta[i][0], i);
egr::EagerUtils::SetHistory(outputs_autograd_meta[i][0], grad_node);
grad_node->SetGradInMeta(*outputs_tensor[i][0], i);
egr::EagerUtils::CheckAndRetainGrad(*outputs_tensor[i][0]);
}
}
VLOG(6) << "PyLayer construct backward node finish...";
......
......@@ -152,7 +152,6 @@ class TestFlipDoubleGradCheck(unittest.TestCase):
gradient_checker.double_grad_check(
[data], out, x_init=[data_arr], place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.double_grad_check_for_dygraph(
self.flip_wrapper, [data], out, x_init=[data_arr], place=place
)
......@@ -184,7 +183,6 @@ class TestFlipTripleGradCheck(unittest.TestCase):
gradient_checker.triple_grad_check(
[data], out, x_init=[data_arr], place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.flip_wrapper, [data], out, x_init=[data_arr], place=place
)
......
......@@ -20,7 +20,6 @@ import paddle
import paddle.nn.functional as F
paddle.set_device('xpu')
paddle.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
unary_api_list = [
paddle.nn.functional.elu,
......@@ -102,6 +101,7 @@ class TestUnaryAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = api(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
......@@ -147,6 +147,7 @@ class TestReduceAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
out.backward()
......@@ -201,12 +202,15 @@ class TestBinaryAPI(unittest.TestCase):
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
x.retain_grads()
y.retain_grads()
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
......@@ -228,6 +232,7 @@ class TestBinaryAPI(unittest.TestCase):
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [2, 3, 4])
......@@ -243,12 +248,15 @@ class TestBinaryAPI(unittest.TestCase):
y = paddle.rand([2, 3, 4])
x.stop_gradient = False
y.stop_gradient = False
x.retain_grads()
y.retain_grads()
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
......@@ -265,6 +273,7 @@ class TestBinaryAPI(unittest.TestCase):
y = 0.5
if isinstance(api, dict):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
......@@ -381,7 +390,9 @@ class TestSundryAPI(unittest.TestCase):
def test_pow_factor(self):
x = paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = paddle.pow(x, 2.0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -391,7 +402,9 @@ class TestSundryAPI(unittest.TestCase):
def test_cast(self):
x = paddle.full([], 1.0, 'float32')
x.stop_gradient = False
x.retain_grads()
out = paddle.cast(x, 'int32')
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -401,7 +414,9 @@ class TestSundryAPI(unittest.TestCase):
def test_clip(self):
x = paddle.uniform([], None, -10, 10)
x.stop_gradient = False
x.retain_grads()
out = paddle.clip(x, -5, 5)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -446,6 +461,7 @@ class TestSundryAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.transpose(x, [])
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -461,6 +477,7 @@ class TestSundryAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.moveaxis(x, [], [])
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -476,6 +493,7 @@ class TestSundryAPI(unittest.TestCase):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
out = paddle.gather(x, index)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -489,6 +507,7 @@ class TestSundryAPI(unittest.TestCase):
)
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [3])
......@@ -541,10 +560,18 @@ class TestSundryAPI(unittest.TestCase):
x2.stop_gradient = False
x3.stop_gradient = False
x1.retain_grads()
x2.retain_grads()
x3.retain_grads()
out1 = paddle.diagflat(x1, 1)
out2 = paddle.diagflat(x2, -1)
out3 = paddle.diagflat(x3, 0)
out1.retain_grads()
out2.retain_grads()
out3.retain_grads()
out1.backward()
out2.backward()
out3.backward()
......@@ -592,7 +619,9 @@ class TestSundryAPI(unittest.TestCase):
def test_scale(self):
x = paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = paddle.scale(x, scale=2.0, bias=1.0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
......@@ -674,24 +703,28 @@ class TestSundryAPI(unittest.TestCase):
x.stop_gradient = False
out = paddle.reshape(x, [])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
out = paddle.reshape(x, [1])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [1])
self.assertEqual(out.grad.shape, [1])
out = paddle.reshape(x, [-1])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [1])
self.assertEqual(out.grad.shape, [1])
out = paddle.reshape(x, [-1, 1])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [1, 1])
......@@ -702,6 +735,7 @@ class TestSundryAPI(unittest.TestCase):
x.stop_gradient = False
out = paddle.reshape(x, [])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
self.assertEqual(out.shape, [])
......@@ -709,6 +743,7 @@ class TestSundryAPI(unittest.TestCase):
new_shape = paddle.to_tensor([1, 1, 1], "int32")
out = paddle.reshape(x, new_shape)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
self.assertEqual(out.shape, [1, 1, 1])
......@@ -716,6 +751,7 @@ class TestSundryAPI(unittest.TestCase):
new_shape = paddle.to_tensor([-1], "int32")
out = paddle.reshape(x, new_shape)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
self.assertEqual(out.shape, [1])
......@@ -723,6 +759,7 @@ class TestSundryAPI(unittest.TestCase):
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
out = paddle.reshape(x, new_shape)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
self.assertEqual(out.shape, [1, 1])
......@@ -765,9 +802,15 @@ class TestSundryAPI(unittest.TestCase):
x1.stop_gradient = False
x2.stop_gradient = False
x1.retain_grads()
x2.retain_grads()
out1 = paddle.sort(x1, axis=-1)
out2 = paddle.sort(x2, axis=0)
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
......@@ -787,10 +830,15 @@ class TestSundryAPI(unittest.TestCase):
x2 = paddle.rand([])
x1.stop_gradient = False
x2.stop_gradient = False
x1.retain_grads()
x2.retain_grads()
out1 = paddle.argsort(x1, axis=-1)
out2 = paddle.argsort(x2, axis=0)
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
......