diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 85184e02b6e1a2700a337fce690370d1c5a0346f..d09892f81bea34415d454b017258fd2a0d4575db 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -133,6 +133,7 @@ TEST(GradOpDescBuilder, MutiInOut) { f::OpDescBind *grad_op = new f::OpDescBind(); f::CompleteGradOpDesc(forw_op, grad_op); + EXPECT_EQ(grad_op->Type(), "mult_io_grad"); ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL); EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); EXPECT_EQ(grad_op->Input("In2_mult"), @@ -156,4 +157,45 @@ TEST(GradOpDescBuilder, MutiInOut) { f::GradVarName("in2_3")})); EXPECT_EQ(grad_op->Output(f::GradVarName("In3")), std::vector({f::GradVarName("in3")})); + delete forw_op; + delete grad_op; } + +TEST(GradOpDescBuilder, IOIgnoredInGradient) { + f::OpDescBind *forw_op = new f::OpDescBind(); + forw_op->SetType("io_ignored"); + forw_op->SetInput("In1", {"in1"}); + forw_op->SetInput("In2_mult", {"in2_1", "in2_2"}); + forw_op->SetInput("In3_mult", {"in3_1", "in3_2"}); + forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"}); + forw_op->SetOutput("Out2", {"out2"}); + + f::OpDescBind *grad_op = new f::OpDescBind(); + f::CompleteGradOpDesc(forw_op, grad_op); + + EXPECT_EQ(grad_op->Type(), "io_ignored_grad"); + // 'In2' and 'Out2' are ignored in gradient calculating + ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL); + EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); + EXPECT_EQ(grad_op->Input("In3_mult"), + std::vector({"in3_1", "in3_2"})); + EXPECT_EQ(grad_op->Input("Out1_mult"), + std::vector({"out1_1", "out1_2"})); + EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")), + std::vector( + {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); + EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")), + std::vector({f::GradVarName("out2")})); + + ASSERT_EQ(grad_op->OutputNames().size(), 3UL); + EXPECT_EQ(grad_op->Output(f::GradVarName("In1")), + std::vector({f::GradVarName("in1")})); + EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")), + std::vector( + {f::GradVarName("in2_1"), f::GradVarName("in2_2")})); + EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")), + std::vector( + {f::GradVarName("in3_1"), f::GradVarName("in3_2")})); + delete forw_op; + delete grad_op; +} \ No newline at end of file