diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 404adb4f379e9cfe8c54de4c6022f8be75a4b8a6..6f86b62b48d1d00d5c565241d5e78c57f1adf5db 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -108,6 +108,16 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
     AddComment("");
   }
 };
+
+class AddOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "x").SetMultiple();
+    AddOutput("Y", "y");
+    AddComment("");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
 
@@ -123,11 +133,14 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
 REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
+REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
+REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
 
-TEST(Backward, simple_grad) {
+TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
+  ASSERT_EQ(1UL, gop->inputs_.size());
   ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
   ASSERT_EQ("rowwise_add_grad", gop->type_);
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
@@ -138,14 +151,99 @@ TEST(Backward, simple_grad) {
   //  LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
-TEST(Backward, not_for_network) {
+TEST(Backward, net_fc_backward_normal) {
+  std::shared_ptr<f::OperatorBase> fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
+  ASSERT_NE(fwd, nullptr);
+  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
+  ASSERT_TRUE(gop->IsNetOp());
+  auto net = static_cast<f::NetOp *>(gop.get());
+
+  ASSERT_NO_THROW(net->DebugString());
+
+  ASSERT_EQ(3UL, net->ops_.size());
+
+  f::OperatorBase &d_sigmoid = *net->ops_[0];
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+
+  f::OperatorBase &d_add = *net->ops_[1];
+  ASSERT_EQ("rowwise_add_grad", d_add.type_);
+
+  f::OperatorBase &d_mul = *net->ops_[2];
+  ASSERT_EQ("mul_grad", d_mul.type_);
+}
+
+TEST(Backward, net_fc_backward_not_have_b) {
+  std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
+      "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
+  ASSERT_NE(fwd, nullptr);
+  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
+  ASSERT_TRUE(gop->IsNetOp());
+  auto net = static_cast<f::NetOp *>(gop.get());
+
+  ASSERT_NO_THROW(net->DebugString());
+
+  ASSERT_EQ(2UL, net->ops_.size());
+
+  f::OperatorBase &d_sigmoid = *net->ops_[0];
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+
+  f::OperatorBase &d_mul = *net->ops_[1];
+  ASSERT_EQ("mul_grad", d_mul.type_);
+}
+
+TEST(Backward, net_input_of_network_not_need_grad) {
+  f::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
+  net.AddOp(
+      f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));
+  net.CompleteAddOp();
+  auto bwd = f::Backward(net, {"X"});  // X@GRAD is not needed.
+  ASSERT_TRUE(bwd->IsNetOp());
+  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
+
+  std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
+      bwd_net->outputs_.begin(), bwd_net->outputs_.end());
+  all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
+
+  for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
+    ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
+              all_output.end());
+  }
+
+  // X@GRAD is not generated.
+  ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            all_output.end());
+
+  ASSERT_EQ(2UL, bwd_net->ops_.size());
+  ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
+  auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
+  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
+  ASSERT_EQ(
+      f::OperatorBase::EMPTY_VAR_NAME(),
+      first_fc_grad->ops_[2]->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+}
+
+TEST(Backward, net_shared_weight) {
+  f::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
+  net.CompleteAddOp();
+
+  auto bwd = f::Backward(net, {});
+  ASSERT_TRUE(bwd->IsNetOp());
+  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
+  ASSERT_EQ(3UL, bwd_net->ops_.size());
+  // The two gradients of the shared weight W are accumulated by an add op.
+  ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
+}
+
+TEST(Backward, op_register_grad_not_for_network) {
   auto fwd =
       f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
                               {{"temporary_index", std::vector<int>{1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
 
-TEST(Backward, all_input_are_not_need) {
+TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"X", "b"});
   ASSERT_TRUE(backward->IsNetOp());
@@ -153,7 +251,7 @@ TEST(Backward, all_input_are_not_need) {
   ASSERT_TRUE(net->ops_.empty());
 }
 
-TEST(Backward, all_output_are_not_need) {
+TEST(Backward, op_all_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"Out"});
   ASSERT_TRUE(backward->IsNetOp());
@@ -161,22 +259,78 @@ TEST(Backward, all_output_are_not_need) {
   ASSERT_TRUE(net->ops_.empty());
 }
 
-TEST(Backward, part_of_output_are_not_need) {
+TEST(Backward, op_part_of_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
   auto backward = f::Backward(*fwd, {"Z"});
   ASSERT_TRUE(backward->IsNetOp());
   auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 2);
+  ASSERT_EQ(net->ops_.size(), 2UL);
 
   auto &fill_zero = *net->ops_[0];
   ASSERT_EQ("fill_zeros_like", fill_zero.type_);
-  ASSERT_EQ(1, fill_zero.inputs_.size());
+  ASSERT_EQ(1UL, fill_zero.inputs_.size());
   ASSERT_EQ("Z", fill_zero.inputs_[0]);
-  ASSERT_EQ(1, fill_zero.outputs_.size());
-  ASSERT_EQ("Z@ZERO", fill_zero.outputs_[0]);
+  ASSERT_EQ(1UL, fill_zero.outputs_.size());
+  ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]);
 
   auto &d_many_out = *net->ops_[1];
   ASSERT_EQ("many_output_op_grad", d_many_out.type_);
-  ASSERT_EQ(1 + 2 + 2, d_many_out.inputs_.size());  // I/O/OG
-  ASSERT_EQ("Z@ZERO", d_many_out.Input("z@GRAD"));
+  ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
+  ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
+            d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+            d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
+            d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+}
+
+TEST(Backward, op_part_of_input_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
+  auto backward = f::Backward(*fwd, {"a"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 1UL);
+
+  auto &grad_mul = *net->ops_[0];
+  ASSERT_EQ(grad_mul.type_, "mul_grad");
+  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
+  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
+  ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            f::OperatorBase::EMPTY_VAR_NAME());
+  ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "b" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_mul.Input("A"), "a");
+  ASSERT_EQ(grad_mul.Input("B"), "b");
+  ASSERT_EQ(grad_mul.Input("Out"), "out");
+}
+
+TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+  f::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {}));
+  net.CompleteAddOp(false);
+  auto backward = f::Backward(net, {"out2"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto bwd_net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(bwd_net->ops_.size(), 1UL);
+
+  auto &grad_fc = *bwd_net->ops_[0];
+  ASSERT_EQ(grad_fc.type_, "fc_grad");
+  ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL);
+  ASSERT_EQ(grad_fc.outputs_.size(), 3UL);
+  ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            f::OperatorBase::EMPTY_VAR_NAME());
+  ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_fc.Input("X"), "out2");
+  ASSERT_EQ(grad_fc.Input("W"), "w3");
+  ASSERT_EQ(grad_fc.Input("b"), "b3");
+  ASSERT_EQ(grad_fc.Input("Out"), "out3");
 }
\ No newline at end of file
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 0a14dc21144153f9a45d5227e54102983c6c2659..644460ee4735a8acc7e5ca32d7983d945fa13826 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
 op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
 op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
+op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
 
 op_library(fc_op SRCS fc_op.cc
     DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc
index 3df3a2cfab6d0c69d660eb78c103738eb0ccc627..d641bc4adaf8c7a84f5dab37632108d929a64730 100644
--- a/paddle/operators/fill_zeros_like_op.cc
+++ b/paddle/operators/fill_zeros_like_op.cc
@@ -19,16 +19,16 @@ limitations under the License.
 */
 namespace paddle {
 namespace operators {
-class FillZerosLike : public framework::OperatorWithKernel {
+class FillZerosLikeOp : public framework::OperatorWithKernel {
 protected:
   void InferShape(
       const std::vector<const framework::Tensor *> &inputs,
       const std::vector<framework::Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1,
-                   "Input size of FillZerosLike must be one.");
+                   "Input size of FillZerosLikeOp must be one.");
     PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
     PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
-                   "Outputs of FillZerosLike must all be set.");
+                   "Outputs of FillZerosLikeOp must all be set.");
     outputs[0]->Resize(inputs[0]->dims());
   }
 };
@@ -44,7 +44,7 @@ public:
 Fill up a vriable with zeros.
 The output will have the same size with input.
 
-)DOC")
+)DOC");
   }
 };
 }  // namespace operators
@@ -53,6 +53,6 @@
 REGISTER_OP(fill_zeros_like,
             paddle::operators::FillZerosLikeOp,
             paddle::operators::FillZerosLikeOpMaker);
-EGISTER_OP_CPU_KERNEL(
+REGISTER_OP_CPU_KERNEL(
     fill_zeros_like,
-    paddle::operators::FillZerosLikeKernal<paddle::platform::CPUPlace, float>);
\ No newline at end of file
+    paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
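
Note: the assertions in backward_test.cc lean entirely on the framework's name-mangling convention: the gradient of a forward variable `v` is named with `OperatorBase::GRAD_VAR_SUFFIX()` (`@GRAD` in the tests), a zero-filled surrogate for a missing output gradient uses `ZERO_VAR_SUFFIX()` (`@ZERO`), and a dropped gradient maps to `EMPTY_VAR_NAME()`. A minimal self-contained sketch of that convention follows; the `@GRAD`/`@ZERO` literals come from the tests above, while the `@EMPTY@` value and the free-function stand-ins are assumptions for illustration only, not Paddle's API.

```cpp
#include <cassert>
#include <string>

// Stand-ins for the OperatorBase name-mangling helpers used in the tests.
// "@GRAD" and "@ZERO" match the literals asserted above; "@EMPTY@" is an
// assumed placeholder for EMPTY_VAR_NAME().
std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }

int main() {
  // The gradient of forward variable "Out" is named "Out@GRAD".
  assert("Out" + GRAD_VAR_SUFFIX() == "Out@GRAD");
  // An output whose gradient is not provided gets a zero-filled surrogate,
  // which is what fill_zeros_like produces ("Z@ZERO" in the tests).
  assert("Z" + ZERO_VAR_SUFFIX() == "Z@ZERO");
  // An input excluded from the backward pass maps to the empty name.
  assert(EMPTY_VAR_NAME() == "@EMPTY@");
  return 0;
}
```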