diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index b6c46302b1f39b79f60e6f6accb4a3e6becb001f..2940af7e3a15f7acf1215f6f3f5bdb0cad7a7386 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -53,11 +53,6 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
 static std::shared_ptr<OperatorBase> BackwardImpl(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
-  // struct OpIdentity {
-  //   size_t local_op_id;
-  //   size_t op_output_offset;
-  // };
-
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     return EmptyOp();
@@ -87,7 +82,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it) {
       auto fwd = *it;
-      auto bwd = Backward(*fwd, no_grad_names);
+      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
       for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
         dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
@@ -136,6 +131,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
                                         {grad_input}, {}));
       }
     }
+
     for (std::string& grad_output : grad_op->outputs_) {
       if (no_grad_names.count(grad_output)) {
         grad_output = OperatorBase::EMPTY_VAR_NAME();
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index a481cb1b2a7ccfd41e5e07265b4809c53649e634..69faee9fb71516c8a13a75b493dc193669623800 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -251,8 +251,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp*>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
-            first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  ASSERT_EQ(
+      f::OperatorBase::EMPTY_VAR_NAME(),
+      first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
 
 TEST(Backward, net_shared_weight) {
@@ -265,14 +266,13 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp*>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
-                              {{"temporary_index", std::vector<int>{1}}});
+  auto fwd = f::OpRegistry::CreateOp(
+      "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"},
+      {{"temporary_index", std::vector<int>{1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
 
@@ -320,11 +320,9 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_TRUE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp*>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 1UL);
+  ASSERT_TRUE(!backward->IsNetOp());
 
-  auto &grad_mul = *net->ops_[0];
+  auto &grad_mul = *backward;
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
@@ -341,10 +339,13 @@ TEST(Backward, op_part_of_input_are_not_need) {
 
 TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   f::NetOp net;
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {}));
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {}));
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {}));
-  net.CompleteAddOp(false);
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
+                                    {"mul_out1", "add_out1", "out1"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
+                                    {"mul_out2", "tmp_out2", "out2"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
+                                    {"mul_out3", "tmp_out3", "out3"}, {}));
+  net.CompleteAddOp();
   auto backward = f::Backward(net, {"out2"});
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp*>(backward.get());