提交 4461f3ca 编写于 作者: F fengjiayi

Merge branch 'feature/backward' of https://github.com/reyoung/Paddle into feature/backward

......@@ -108,6 +108,16 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
AddComment("");
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple();
AddOutput("Y", "y");
AddComment("");
}
};
} // namespace framework
} // namespace paddle
......@@ -123,8 +133,10 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
TEST(Backward, simple_grad) {
TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
......@@ -139,7 +151,7 @@ TEST(Backward, simple_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, fc_backward_normal) {
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
......@@ -161,7 +173,7 @@ TEST(Backward, fc_backward_normal) {
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, fc_backward_not_have_b) {
TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
......@@ -180,12 +192,12 @@ TEST(Backward, fc_backward_not_have_b) {
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, input_layer_not_need_grad) {
TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
net.AddOp(
f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));
net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
......@@ -198,16 +210,40 @@ TEST(Backward, input_layer_not_need_grad) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
}
// Not Generated X
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(2, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3, first_fc_grad->ops_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, net_shared_weight) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
}
TEST(Backward, not_for_network) {
TEST(Backward, op_register_grad_not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, all_input_are_not_need) {
TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
......@@ -215,7 +251,7 @@ TEST(Backward, all_input_are_not_need) {
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, all_output_are_not_need) {
TEST(Backward, op_all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
......@@ -223,7 +259,7 @@ TEST(Backward, all_output_are_not_need) {
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, part_of_output_are_not_need) {
TEST(Backward, op_part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
......@@ -248,7 +284,7 @@ TEST(Backward, part_of_output_are_not_need) {
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, part_of_input_are_not_need) {
TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"});
ASSERT_TRUE(backward->IsNetOp());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册