diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 25ebcefa03ff657b6fc41e3be05c710606add194..472a671e470c5411750d56f91721d41c4461e3a8 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -46,6 +46,7 @@ static std::vector<size_t> InSetIdx(
 
 static std::shared_ptr<OperatorBase> EmptyOp() {
   auto net_op = std::make_shared<NetOp>();
+  net_op->type_ = "@EMPTY_OP@";
   net_op->CompleteAddOp();
   return net_op;
 }
@@ -140,7 +141,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
               [](const Pos& l, const Pos& r) { return l.first > r.first; });
 
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first, pos.second);
+      net->InsertOp(pos.first + 1, pos.second);
     }
 
   } else {
@@ -167,7 +168,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     }
     net->AddOp(grad_op);
   }
-
+  net->type_ = "@GENERATED_BACKWARD@";
   net->CompleteAddOp();
   return net;
 }
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 420cc65fefd14847ec64a6d1b7f9829f90de3b06..00c11563af823b4e74d656b432dbb553bc115e13 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -269,15 +269,14 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp*>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
+  ASSERT_EQ("add", bwd_net->ops_[2]->type_);
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  //  auto fwd =
-  //      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
-  //                              {{"temporary_index", std::vector<int>{1}}});
+  auto fwd = f::OpRegistry::CreateOp(
+      "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
+      {{"temporary_index", std::vector<int>{0, 1}}});
 
-  auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
 
@@ -350,13 +349,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                                     {"mul_out3", "tmp_out3", "out3"}, {}));
   net.CompleteAddOp();
   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
+  LOG(INFO) << backward->DebugString();
+
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp*>(backward.get());
   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
-  EXPECT_EQ(bwd_net->ops_[0]->type_, "");
-  EXPECT_EQ(bwd_net->ops_[1]->type_, "");
-  EXPECT_EQ(bwd_net->ops_[2]->type_, "");
-
   auto &grad_fc = *bwd_net->ops_[0];
   EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
   EXPECT_EQ(grad_fc.outputs_.size(), 3UL);
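
A note on the `InsertOp(pos.first + 1, pos.second)` change: the positions are sorted in descending order before insertion, so splicing an op in at a larger index never shifts the smaller indices that are still pending, and the `+ 1` puts the new op immediately after the op recorded in `pos.first` (assuming InsertOp, like std::vector::insert, inserts before the given index). The standalone sketch below is not Paddle code; a plain std::vector<std::string> stands in for NetOp::ops_, and the op names are made up, purely to show the index mechanics.

// Standalone sketch, not Paddle code: a plain std::vector<std::string> stands
// in for NetOp::ops_, and the op names below are made up for illustration.
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> ops = {"fc1", "fc2", "fc3"};

  // Each pair means: "an accumulation op must end up right AFTER ops[first]".
  std::vector<std::pair<size_t, std::string>> insert_position = {
      {0, "add@after_fc1"}, {2, "add@after_fc3"}};

  // Sort descending, as the diff does, so an insertion at a larger index
  // never shifts the smaller indices that have not been processed yet.
  std::sort(insert_position.begin(), insert_position.end(),
            [](const auto& l, const auto& r) { return l.first > r.first; });

  for (const auto& pos : insert_position) {
    // Inserting at pos.first + 1 places the new op immediately after
    // ops[pos.first], mirroring net->InsertOp(pos.first + 1, pos.second).
    ops.insert(ops.begin() + pos.first + 1, pos.second);
  }

  for (const auto& op : ops) std::cout << op << "\n";
  // Prints: fc1, add@after_fc1, fc2, fc3, add@after_fc3
}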