diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 52eccfba6978d9a537ccc28100ee0740af75a865..dac57c2e22c750122712c378dc553e8e74909057 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -53,11 +53,6 @@ static std::shared_ptr EmptyOp() { static std::shared_ptr BackwardImpl( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id) { - // struct OpIdentity { - // size_t local_op_id; - // size_t op_output_offset; - // }; - if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), no_grad_names)) { return EmptyOp(); @@ -87,9 +82,7 @@ static std::shared_ptr BackwardImpl( for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it) { auto fwd = *it; - // for (auto& fwd : forwardNet.ops_) { - // auto bwd = Backward(*fwd, no_grad_names); - auto bwd = Backward(*fwd, no_grad_names); + auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id); net->AddOp(bwd); for (size_t i = 0; i < bwd->outputs_.size(); ++i) { dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id); @@ -138,6 +131,7 @@ static std::shared_ptr BackwardImpl( {grad_input}, {})); } } + for (std::string& grad_output : grad_op->outputs_) { if (no_grad_names.count(grad_output)) { grad_output = OperatorBase::EMPTY_VAR_NAME(); diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 63194e78fcfe9e0118e9f85a2e88969d3478e78d..7185872d0a0c99247776fdc8727157a57b6e1ee4 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -230,8 +230,9 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3UL, first_fc_grad->ops_.size()); - ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), - first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); + ASSERT_EQ( + f::OperatorBase::EMPTY_VAR_NAME(), + first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX())); } TEST(Backward, net_shared_weight) { @@ -244,14 +245,13 @@ TEST(Backward, net_shared_weight) { ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); - LOG(INFO) << bwd_net->DebugString(); ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_); } TEST(Backward, op_register_grad_not_for_network) { - auto fwd = - f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, - {{"temporary_index", std::vector{1}}}); + auto fwd = f::OpRegistry::CreateOp( + "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"}, + {{"temporary_index", std::vector{1}}}); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); } @@ -299,11 +299,9 @@ TEST(Backward, op_part_of_output_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto backward = f::Backward(*fwd, {"a"}); - ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); - ASSERT_EQ(net->ops_.size(), 1UL); + ASSERT_TRUE(!backward->IsNetOp()); - auto &grad_mul = *net->ops_[0]; + auto &grad_mul = *backward; ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.outputs_.size(), 2UL); @@ -320,10 +318,13 @@ TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) { f::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {})); - net.CompleteAddOp(false); + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, + {"mul_out1", "add_out1", "out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, + {"mul_out2", "tmp_out2", "out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, + {"mul_out3", "tmp_out3", "out3"}, {})); + net.CompleteAddOp(); auto backward = f::Backward(net, {"out2"}); ASSERT_TRUE(backward->IsNetOp()); auto bwd_net = static_cast(backward.get());