未验证 提交 15087552 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad] Enabled test_imperative_triple_grad test cases under eager_mode (#41612)

* [DoubleGrad] Enabled double grad test cases in eager_mode for test_imperative_double_grad

* Fixed elementwise issue

* Addressed CI failures

* [DoubleGrad] Enabled test_imperative_triple_grad test cases under eager_mode

* Fixed minor issues
上级 e53d1837
...@@ -2011,8 +2011,7 @@ static std::string GenerateSingleOpBase( ...@@ -2011,8 +2011,7 @@ static std::string GenerateSingleOpBase(
"egr::EagerUtils::TrySyncToVars(egr::EagerUtils::" "egr::EagerUtils::TrySyncToVars(egr::EagerUtils::"
"RecoverTensorWrapper(" "RecoverTensorWrapper("
"&" "&"
"this->%s, " "this->%s)) },";
"nullptr)) },";
ins_contents_str += ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name); grad_input_name, struct_fwd_input_name);
...@@ -2058,15 +2057,15 @@ static std::string GenerateSingleOpBase( ...@@ -2058,15 +2057,15 @@ static std::string GenerateSingleOpBase(
const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE = const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
" if(this->%s.size() > 0) %s[\"%s\"] = " " if(this->%s.size() > 0) %s[\"%s\"] = "
"egr::EagerUtils::TrySyncToVars(egr::EagerUtils::" "egr::EagerUtils::TrySyncToVars(egr::EagerUtils::"
"RecoverTensorWrapper(&this->%s, nullptr));\n"; "RecoverTensorWrapper(&this->%s));\n";
generated_grad_function_body += paddle::string::Sprintf( generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, struct_fwd_input_name, DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, struct_fwd_input_name,
ins_name, grad_input_name, struct_fwd_input_name); ins_name, grad_input_name, struct_fwd_input_name);
} else { } else {
const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE = const char* DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE =
" auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s, " " auto %s = egr::EagerUtils::RecoverTensorWrapper(&this->%s);\n"
"nullptr);\n if(%s.initialized()) %s[\"%s\"] = " " if(%s.initialized()) %s[\"%s\"] = "
"egr::EagerUtils::TrySyncToVars(%s);\n"; " egr::EagerUtils::TrySyncToVars(%s);\n";
generated_grad_function_body += paddle::string::Sprintf( generated_grad_function_body += paddle::string::Sprintf(
DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name, DISPENSABLE_GRAD_INS_FWD_CONTENT_TEMPLATE, grad_input_name,
struct_fwd_input_name, grad_input_name, ins_name, grad_input_name, struct_fwd_input_name, grad_input_name, ins_name, grad_input_name,
......
...@@ -23,7 +23,8 @@ import os ...@@ -23,7 +23,8 @@ import os
######################## ########################
ops_to_fill_zero_for_empty_grads = set([ ops_to_fill_zero_for_empty_grads = set([
"split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad", "split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
"sigmoid_triple_grad, add_double_grad" "sigmoid_double_grad", "sigmoid_triple_grad", "add_double_grad",
"add_triple_grad"
]) ])
# For API dispatch used at python-level # For API dispatch used at python-level
......
...@@ -236,7 +236,7 @@ FORWARD_BODY_TEMPLATE = \ ...@@ -236,7 +236,7 @@ FORWARD_BODY_TEMPLATE = \
{} {}
// SetAttributes // SetAttributes
{} {}
// SetTensorWrappers // Set TensorWrappers for Forward Inputs
{} {}
// SetGradOutMeta & SetEdges // SetGradOutMeta & SetEdges
{} {}
...@@ -245,6 +245,8 @@ FORWARD_BODY_TEMPLATE = \ ...@@ -245,6 +245,8 @@ FORWARD_BODY_TEMPLATE = \
{} {}
{} {}
{} {}
{}
// Set TensorWrappers for Forward Outputs
{} {}
}} }}
""" """
...@@ -720,7 +722,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -720,7 +722,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_attributes_str = "\n".join(set_attributes_list) set_attributes_str = "\n".join(set_attributes_list)
# SetTensorWrappers # SetTensorWrappers
set_tensor_wrappers_list = [] set_input_tensor_wrappers_list = []
set_output_tensor_wrappers_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys()) num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (atype, is_fwd_input, for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items(): pos) in backward_forward_inputs_map.items():
...@@ -732,6 +735,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -732,6 +735,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);" set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
else: else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});" set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else: else:
if num_fwd_outputs > 1: if num_fwd_outputs > 1:
# Aligned with forward output position # Aligned with forward output position
...@@ -743,8 +747,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -743,8 +747,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);" set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
else: else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers) set_output_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) set_input_tensor_wrappers_str = "\n".join(
set_input_tensor_wrappers_list)
set_output_tensor_wrappers_str = "\n".join(
set_output_tensor_wrappers_list)
# SetGradOutMeta & SetEdges # SetGradOutMeta & SetEdges
set_grad_out_meta_list = [] set_grad_out_meta_list = []
...@@ -801,9 +808,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -801,9 +808,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.node_creation_str = FORWARD_BODY_TEMPLATE.format( self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str, pass_stop_gradient_args_str, node_creation_event_str, pass_stop_gradient_args_str,
node_construction_str, set_attributes_str, set_tensor_wrappers_str, node_construction_str, set_attributes_str,
set_grad_out_meta_str, set_edges_str, set_out_rank_str, set_input_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str,
set_history_str, set_grad_in_meta_str, set_retain_grad_str) set_out_rank_str, set_history_str, set_grad_in_meta_str,
set_retain_grad_str, set_output_tensor_wrappers_str)
def run(self): def run(self):
# Basic Validation Check # Basic Validation Check
...@@ -1296,7 +1304,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): ...@@ -1296,7 +1304,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
transformed_tensor_name = self.TransformToNextGradName(name) transformed_tensor_name = self.TransformToNextGradName(name)
is_optional = (name in self.optional_inputs) is_optional = (name in self.optional_inputs)
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());" tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
if is_optional: if is_optional:
tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format( tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name, transformed_tensor_name, transformed_tensor_name,
......
...@@ -731,16 +731,6 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -731,16 +731,6 @@ std::vector<paddle::experimental::Tensor> RunBackward(
continue; continue;
} }
auto* next_node = next_node_shared.get();
if (!node_input_buffers_dict.count(next_node)) {
const auto& input_meta = next_node->InputMeta();
auto grad_tensor_holder =
std::make_unique<GradTensorHolder>(input_meta);
VLOG(6) << "Construct GradTensorHolder for grad node: "
<< next_node->name();
node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
}
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
j, grad_output_tensors[i].size(), j, grad_output_tensors[i].size(),
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
...@@ -760,8 +750,19 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -760,8 +750,19 @@ std::vector<paddle::experimental::Tensor> RunBackward(
<< ", rank: " << j << ", rank: " << j
<< " 's name is: " << grad_output_tensor.name(); << " 's name is: " << grad_output_tensor.name();
auto* next_node = next_node_shared.get();
if (!node_input_buffers_dict.count(next_node)) {
const auto& input_meta = next_node->InputMeta();
auto grad_tensor_holder =
std::make_unique<GradTensorHolder>(input_meta);
VLOG(6) << "Construct GradTensorHolder for grad node: "
<< next_node->name();
node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
}
VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first
<< ", rank: " << edge_rank.second; << ", rank: " << edge_rank.second;
node_input_buffers_dict[next_node]->add( node_input_buffers_dict[next_node]->add(
edge_rank.first, edge_rank.second, grad_output_tensor); edge_rank.first, edge_rank.second, grad_output_tensor);
......
...@@ -59,7 +59,7 @@ class RunCustomOpNode : public GradNodeBase { ...@@ -59,7 +59,7 @@ class RunCustomOpNode : public GradNodeBase {
std::vector<egr::TensorWrapper>* fwd_var) { std::vector<egr::TensorWrapper>* fwd_var) {
std::vector<paddle::experimental::Tensor> res; std::vector<paddle::experimental::Tensor> res;
for (size_t i = 0; i < fwd_var->size(); i++) { for (size_t i = 0; i < fwd_var->size(); i++) {
res.emplace_back(fwd_var->at(i).recover(nullptr)); res.emplace_back(fwd_var->at(i).recover());
} }
return res; return res;
} }
......
...@@ -61,6 +61,10 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) { ...@@ -61,6 +61,10 @@ void GradNodeBase::AddEdges(std::vector<AutogradMeta*>* metas, size_t slot_id) {
if (!node || !node.get()) { if (!node || !node.get()) {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta)); meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
} }
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << meta->GetMutableGradNode()->name()
<< " (addr: " << meta->GetMutableGradNode().get() << ")";
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
...@@ -84,7 +88,9 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) { ...@@ -84,7 +88,9 @@ void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) {
meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta)); meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
} }
VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from " VLOG(6) << "Add Edges for slot: " << slot_id << ", the Edge is from "
<< this->name() << " to " << meta->GetMutableGradNode()->name(); << this->name() << " (addr: " << this << ") "
<< " to " << meta->GetMutableGradNode()->name()
<< " (addr: " << meta->GetMutableGradNode().get() << ")";
adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(),
meta->OutRankInfo()); meta->OutRankInfo());
......
...@@ -110,6 +110,7 @@ void GradTensorHolder::add(size_t slot_id, size_t rank, ...@@ -110,6 +110,7 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
"got tensor: %s is empty please check you network " "got tensor: %s is empty please check you network "
"and make sure it creates grads.", "and make sure it creates grads.",
t.name())); t.name()));
if (t.is_dense_tensor()) { if (t.is_dense_tensor()) {
if (buffer_tensor.is_dense_tensor()) { if (buffer_tensor.is_dense_tensor()) {
buffer_tensor = add_final_state_dygraph_function(t, buffer_tensor); buffer_tensor = add_final_state_dygraph_function(t, buffer_tensor);
......
...@@ -77,16 +77,17 @@ class TensorWrapper { ...@@ -77,16 +77,17 @@ class TensorWrapper {
intermidiate_tensor_.set_name(tensor.name() + "@Saved"); intermidiate_tensor_.set_name(tensor.name() + "@Saved");
// If an output is marked "intermedaite", we won't create auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
// autograd_meta for it. if (tensor_autograd_meta) {
// In that case, simply skip OutRankInfo Copy auto autograd_meta = std::make_shared<AutogradMeta>(
if (EagerUtils::nullable_autograd_meta(tensor)) { Edge(nullptr, EagerUtils::OutRankInfo(tensor)));
out_rank_info_ = EagerUtils::OutRankInfo(tensor); autograd_meta->SetStopGradient(tensor_autograd_meta->StopGradient());
intermidiate_tensor_.set_autograd_meta(autograd_meta);
weak_grad_node_ = tensor_autograd_meta->GetMutableGradNode();
} }
} }
paddle::experimental::Tensor recover( paddle::experimental::Tensor recover() {
const std::shared_ptr<GradNodeBase>& grad_node) {
VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name() VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name()
<< " for wrapper"; << " for wrapper";
if (!intermidiate_tensor_.defined()) { if (!intermidiate_tensor_.defined()) {
...@@ -99,9 +100,20 @@ class TensorWrapper { ...@@ -99,9 +100,20 @@ class TensorWrapper {
// if it's full_reserved just return the full copy of tensor // if it's full_reserved just return the full copy of tensor
paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_; paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
if (!full_reserved_) { if (!full_reserved_) {
std::shared_ptr<GradNodeBase> new_grad_node = grad_node; std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
auto p_ab_autograd_meta = if (new_grad_node) {
std::make_shared<AutogradMeta>(Edge(new_grad_node, out_rank_info_)); VLOG(3) << "Recovered TensorWrapper with GradNode "
<< new_grad_node->name() << " addr: " << new_grad_node.get();
} else {
VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
}
auto* intermediate_autograd_meta =
EagerUtils::unsafe_autograd_meta(intermidiate_tensor_);
auto p_ab_autograd_meta = std::make_shared<AutogradMeta>(
Edge(new_grad_node, intermediate_autograd_meta->OutRankInfo()));
p_ab_autograd_meta->SetStopGradient(
intermediate_autograd_meta->StopGradient());
recovered_tensor.set_autograd_meta( recovered_tensor.set_autograd_meta(
std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>( std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
p_ab_autograd_meta)); p_ab_autograd_meta));
...@@ -149,8 +161,8 @@ class TensorWrapper { ...@@ -149,8 +161,8 @@ class TensorWrapper {
private: private:
bool full_reserved_ = false; bool full_reserved_ = false;
bool no_need_buffer_ = false; bool no_need_buffer_ = false;
std::pair<size_t, size_t> out_rank_info_;
paddle::experimental::Tensor intermidiate_tensor_; paddle::experimental::Tensor intermidiate_tensor_;
std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
uint32_t inplace_version_snapshot_ = 0; uint32_t inplace_version_snapshot_ = 0;
}; };
} // namespace egr } // namespace egr
...@@ -41,7 +41,7 @@ TEST(TensorWrapper, Basic) { ...@@ -41,7 +41,7 @@ TEST(TensorWrapper, Basic) {
et1.set_autograd_meta(auto_grad0); et1.set_autograd_meta(auto_grad0);
et1.set_name("et1"); et1.set_name("et1");
auto tw0 = egr::TensorWrapper(et1, true); auto tw0 = egr::TensorWrapper(et1, true);
auto recover_et1 = tw0.recover(std::make_shared<eager_test::GradTestNode>()); auto recover_et1 = tw0.recover();
CHECK_EQ(recover_et1.name(), std::string("et1")); CHECK_EQ(recover_et1.name(), std::string("et1"));
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first, CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first,
egr::EagerUtils::OutRankInfo(et1).first); egr::EagerUtils::OutRankInfo(et1).first);
...@@ -67,7 +67,7 @@ TEST(TensorWrapper, Basic) { ...@@ -67,7 +67,7 @@ TEST(TensorWrapper, Basic) {
auto auto_grad1 = std::make_shared<egr::AutogradMeta>(edge1); auto auto_grad1 = std::make_shared<egr::AutogradMeta>(edge1);
et2.set_autograd_meta(auto_grad1); et2.set_autograd_meta(auto_grad1);
auto tw1 = egr::TensorWrapper(et2, false); auto tw1 = egr::TensorWrapper(et2, false);
auto recover_et2 = tw1.recover(grad_test_node1); auto recover_et2 = tw1.recover();
CHECK_EQ(recover_et2.name(), std::string("et2@Saved")); CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first, CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first,
egr::EagerUtils::OutRankInfo(et2).first); egr::EagerUtils::OutRankInfo(et2).first);
...@@ -76,7 +76,5 @@ TEST(TensorWrapper, Basic) { ...@@ -76,7 +76,5 @@ TEST(TensorWrapper, Basic) {
// Test Raw recover // Test Raw recover
paddle::experimental::Tensor et3; paddle::experimental::Tensor et3;
auto tw2 = egr::TensorWrapper(et3, true); auto tw2 = egr::TensorWrapper(et3, true);
CHECK( CHECK(tw2.recover().initialized() == false);
tw2.recover(std::make_shared<eager_test::GradTestNode>()).initialized() ==
false);
} }
...@@ -360,16 +360,15 @@ void EagerUtils::Output2Result( ...@@ -360,16 +360,15 @@ void EagerUtils::Output2Result(
} }
paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper( paddle::experimental::Tensor EagerUtils::RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) { TensorWrapper* tw) {
return tw->recover(grad_node); return tw->recover();
} }
std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper( std::vector<paddle::experimental::Tensor> EagerUtils::RecoverTensorWrapper(
std::vector<TensorWrapper>* tw, std::vector<TensorWrapper>* tw) {
const std::shared_ptr<GradNodeBase>& grad_node) {
std::vector<paddle::experimental::Tensor> ret; std::vector<paddle::experimental::Tensor> ret;
for (auto& t : *tw) { for (auto& t : *tw) {
ret.emplace_back(t.recover(grad_node)); ret.emplace_back(t.recover());
} }
return ret; return ret;
} }
......
...@@ -174,11 +174,9 @@ class EagerUtils { ...@@ -174,11 +174,9 @@ class EagerUtils {
const std::shared_ptr<EagerVariable>& view_output_var); const std::shared_ptr<EagerVariable>& view_output_var);
// TensorWrapper Utils // TensorWrapper Utils
static paddle::experimental::Tensor RecoverTensorWrapper( static paddle::experimental::Tensor RecoverTensorWrapper(TensorWrapper* tw);
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
static std::vector<paddle::experimental::Tensor> RecoverTensorWrapper( static std::vector<paddle::experimental::Tensor> RecoverTensorWrapper(
std::vector<TensorWrapper>* tw, std::vector<TensorWrapper>* tw);
const std::shared_ptr<GradNodeBase>& grad_node);
// Intermidate needed remove this once we don't need legacy // Intermidate needed remove this once we don't need legacy
// Inner Method // Inner Method
......
...@@ -209,7 +209,9 @@ class TestDygraphTripleGrad(TestCase): ...@@ -209,7 +209,9 @@ class TestDygraphTripleGrad(TestCase):
self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected)) self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
def test_all_cases(self): def test_all_cases(self):
if _in_legacy_dygraph(): self.func_exception()
self.func_example_with_gradient_and_create_graph()
with _test_eager_guard():
self.func_exception() self.func_exception()
self.func_example_with_gradient_and_create_graph() self.func_example_with_gradient_and_create_graph()
...@@ -296,7 +298,8 @@ class TestDygraphTripleGradBradcastCase(TestCase): ...@@ -296,7 +298,8 @@ class TestDygraphTripleGradBradcastCase(TestCase):
self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected)) self.assertTrue(np.allclose(dddx_grad_actual, dddx_expected))
def test_all_cases(self): def test_all_cases(self):
if _in_legacy_dygraph(): self.func_example_with_gradient_and_create_graph()
with _test_eager_guard():
self.func_example_with_gradient_and_create_graph() self.func_example_with_gradient_and_create_graph()
......
...@@ -1458,7 +1458,7 @@ ...@@ -1458,7 +1458,7 @@
func : GeneralTernaryGradInferMeta func : GeneralTernaryGradInferMeta
param : [out, fwd_grad_out, grad_grad_x] param : [out, fwd_grad_out, grad_grad_x]
kernel : kernel :
func : sigmoid_double_grad func : sigmoid_triple_grad
- backward_api : silu_grad - backward_api : silu_grad
forward : silu (Tensor x) -> Tensor(out) forward : silu (Tensor x) -> Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册