From 831d4e1c85dedc2bca8cc997ccc612208dc05c38 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 26 Jul 2017 19:37:40 +0800
Subject: [PATCH] Refining Unittest

---
 paddle/framework/CMakeLists.txt          |   2 +-
 paddle/framework/backward_test.cc        | 142 ++++++++++++++++++++++-
 paddle/framework/grad_op_builder.cc      |  19 ++-
 paddle/framework/grad_op_builder.h       |   4 +-
 paddle/framework/grad_op_builder_test.cc |   2 +-
 paddle/framework/op_registry.h           |   7 +-
 6 files changed, 152 insertions(+), 24 deletions(-)

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 66f516a963..7febaaa527 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -33,4 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
 
 cc_library(backward SRCS backward.cc DEPS net)
-cc_test(backward_test SRCS backward_test.cc DEPS net)
+cc_test(backward_test SRCS backward_test.cc DEPS backward)
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index b2286facfe..cc00279db5 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -12,8 +12,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/framework/backward.h"
 #include <gtest/gtest.h>
+#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
+
 namespace paddle {
 namespace framework {
 
@@ -24,10 +27,9 @@ class EmptyOp : public OperatorBase {
                 const platform::DeviceContext &dev_ctx) const override {}
 };
 
-class RowwiseAddOp : public EmptyOp {};
-class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
+class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
-  RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+  RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input X of Add").IgnoreGradient();
     AddInput("b", "Bias of Add").IgnoreGradient();
@@ -36,15 +38,143 @@ class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
-class RowwiseAddGradOp : public EmptyOp {};
+class MulOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("A", "A");
+    AddInput("B", "B");
+    AddOutput("Out", "Out");
+    AddComment("Mul");
+  }
+};
+
+class SigmoidOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "X");
+    AddOutput("Y", "Y");
+    AddComment("Sigmoid");
+  }
+};
+
+class FcOp : public NetOp {
+ public:
+  void Init() override {
+    AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
+                               {Output("before_act")}, {}));
+    auto b_name = Input("b");
+    if (b_name != EMPTY_VAR_NAME()) {
+      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
+                                 {Output("before_act")}, {}));
+    }
+    AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
+                               {Output("Out")}, {}));
+    CompleteAddOp(false);
+  }
+};
+
+class FcOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "x");
+    AddInput("W", "w");
+    AddInput("b", "b");
+    AddOutput("before_act", "before act").SetTemporary();
+    AddOutput("Out", "");
+    AddComment("");
+  }
+};
+
+class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
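+  // One input ("x") and two outputs ("y", "z"); the
+  // part_of_output_are_not_need test below drops the gradient of "z",
+  // so Backward must insert a fill_zeros_like op to zero it.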
+ public:
+  ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("x", "x");
+    AddOutput("y", "y");
+    AddOutput("z", "z");
+    AddComment("");
+  }
+};
+
+class FillZeroOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("x", "x");
+    AddOutput("out", "out");
+    AddComment("");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
 
 namespace f = paddle::framework;
-REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker);
-REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp);
+using EnforceNotMet = paddle::platform::EnforceNotMet;
+REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
+REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
+REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
+REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
+REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
+REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
+REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
+REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 
 TEST(Backward, simple_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
+  auto gop = f::OpRegistry::CreateGradOp(*fwd);
+  ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
+  ASSERT_EQ("rowwise_add_grad", gop->type_);
+  ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
+  ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
+
+  //  LOG(INFO) << gop->Output("X" + "@GRAD");
+}
+
+TEST(Backward, not_for_network) {
+  auto fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+                              {{"temporary_index", std::vector<int>{1}}});
+  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
+}
+
+TEST(Backward, all_input_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
+  auto backward = f::Backward(*fwd, {"X", "b"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_TRUE(net->ops_.empty());
+}
+
+TEST(Backward, all_output_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
+  auto backward = f::Backward(*fwd, {"Out"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_TRUE(net->ops_.empty());
+}
+
+TEST(Backward, part_of_output_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
+  auto backward = f::Backward(*fwd, {"Z"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 2);
+
+  auto &fill_zero = *net->ops_[0];
+  ASSERT_EQ("fill_zeros_like", fill_zero.type_);
+  ASSERT_EQ(1, fill_zero.inputs_.size());
+  ASSERT_EQ("Z", fill_zero.inputs_[0]);
+  ASSERT_EQ(1, fill_zero.outputs_.size());
+  ASSERT_EQ("Z@ZERO", fill_zero.outputs_[0]);
+
+  auto &d_many_out = *net->ops_[1];
+  ASSERT_EQ("many_output_op_grad", d_many_out.type_);
+  ASSERT_EQ(1 + 2 + 2, d_many_out.inputs_.size());  // I/O/OG
+  ASSERT_EQ("Z@ZERO", d_many_out.Input("z@GRAD"));
 }
\ No newline at end of file
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index 6235be75f2..dd686cc782 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -20,7 +20,7 @@ namespace framework {
 
 OperatorBase* GradOpBuilder::Build() {
   BuildOpInOutArgList();
-  std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
+  std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
   OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
   grad_op->type_ = grad_op_type;
   CompleteGradOp(grad_op);
@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
 }
 
 void GradOpBuilder::BuildOpInOutArgList() {
-  const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
-  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
+  const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
+  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
   const std::vector<int>& in_format =
-      op_->attrs_.count("input_format")
-          ? op_->GetAttr<std::vector<int>>("input_format")
+      op_.attrs_.count("input_format")
+          ? op_.GetAttr<std::vector<int>>("input_format")
           : std::vector<int>();
   const std::vector<int>& out_format =
-      op_->attrs_.count("output_format")
-          ? op_->GetAttr<std::vector<int>>("output_format")
+      op_.attrs_.count("output_format")
+          ? op_.GetAttr<std::vector<int>>("output_format")
           : std::vector<int>();
   for (const auto& var : op_proto.inputs()) {
     arg_list_.emplace_back(
@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
   }
   (*varmap)[var_name] = idx++;
   size_t pre_sz = in_out.size();
-  auto base_it =
-      arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
+  auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
   std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
             std::back_inserter(in_out));
   if (is_grad) {
@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
 }
 
 void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->attrs_ = op_->attrs_;
+  grad_op->attrs_ = op_.attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
   VarIndexMap* grad_varmap = new VarIndexMap();
diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h
index 2ecf39479b..cc7a76f372 100644
--- a/paddle/framework/grad_op_builder.h
+++ b/paddle/framework/grad_op_builder.h
@@ -29,7 +29,7 @@ class GradOpBuilder {
   using VarIndexMap = std::unordered_map<std::string, int>;
 
  public:
-  GradOpBuilder(const OperatorBase* op) : op_(op) {}
+  GradOpBuilder(const OperatorBase& op) : op_(op) {}
   OperatorBase* Build();
 
  private:
@@ -40,7 +40,7 @@ class GradOpBuilder {
                          std::vector<int>& format, VarIndexMap* varmap,
                          int& idx, bool is_grad) const;
   void CompleteGradOp(OperatorBase* grad_op) const;
-  const OperatorBase* op_;
+  const OperatorBase& op_;
   std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
 };
diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc
index 288a7841cd..e9cf3b9798 100644
--- a/paddle/framework/grad_op_builder_test.cc
+++ b/paddle/framework/grad_op_builder_test.cc
@@ -11,7 +11,7 @@ namespace framework {
 TEST(GradOpBuilder, AddTwo) {
   std::shared_ptr<OperatorBase> add_op(
       OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
-  std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
+  std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(*add_op);
   EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
   EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
   EXPECT_EQ(grad_add_op->Input("X"), "x");
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index e4ac8a6e76..cee20b1112 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -303,11 +303,10 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
 
-  static std::shared_ptr<OperatorBase> CreateGradOp(
-      std::shared_ptr<OperatorBase> op) {
-    PADDLE_ENFORCE(!op->IsNetOp(),
+  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
+    PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
-    GradOpBuilder builder(op.get());
+    GradOpBuilder builder(op);
     std::shared_ptr<OperatorBase> grad_op(builder.Build());
     grad_op->Init();
    return grad_op;
-- 
GitLab
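
For reference, a minimal sketch of how the interfaces touched by this patch
are called after it applies (it mirrors the tests above; variable names are
illustrative, not part of the patch):

    namespace f = paddle::framework;
    // CreateGradOp now takes the forward op by const reference instead of
    // by shared_ptr, so callers dereference the pointer they hold.
    auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
    auto gop = f::OpRegistry::CreateGradOp(*fwd);  // single gradient op
    // Backward takes the forward op plus the set of variables whose
    // gradients are not needed, and returns a NetOp when the gradient
    // expands into several ops.
    auto backward = f::Backward(*fwd, {"b"});
    if (backward->IsNetOp()) {
      auto net = static_cast<f::NetOp *>(backward.get());  // inspect net->ops_
    }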