Commit 658588a6 authored by dongzhihong

"format test case"

Parent b2e1c48e
@@ -167,15 +167,28 @@ TEST(Backward, simple_op_not_need_grad) {
   auto gop = f::Backward(*fwd, {"X"});
   LOG(INFO) << "full " << gop->DebugString();
   ASSERT_NE(std::find(gop->outputs_.begin(), gop->outputs_.end(),
-                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+                      std::string("X") + f::OperatorBase::GRAD_VAR_SUFFIX()),
             gop->outputs_.end());
   auto no_input_gop = f::Backward(*fwd, {"X", "b"});
-  LOG(INFO) << "no input gop " << no_input_gop->DebugString();
+  LOG(INFO) << "no input gop " << gop->DebugString();
   ASSERT_NE(no_input_gop, nullptr);
-  ASSERT_EQ(std::vector<std::string>{}, no_input_gop->outputs_);
+  typedef std::vector<std::string> Vec;
+  auto vector_equal = [](const Vec &l, const Vec &r) {
+    if (l.size() != r.size()) return false;
+    for (size_t i = 0; i < l.size(); ++i) {
+      if (l[i] != r[i]) return false;
+    }
+    return true;
+  };
+  ASSERT_EQ(vector_equal(std::vector<std::string>{}, no_input_gop->outputs_),
+            true);
   ASSERT_EQ(
-      std::vector<std::string>{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()},
-      no_input_gop->inputs_);
+      vector_equal(
+          std::vector<std::string>{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()},
+          no_input_gop->inputs_),
+      true);
   // auto no_output_gop = f::Backward(*fwd, {"Out"});
   // ASSERT_EQ(std::vector<std::string>{"X" +
   // f::OperatorBase::GRAD_VAR_SUFFIX(), "b"})
@@ -251,9 +264,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(
-      f::OperatorBase::EMPTY_VAR_NAME(),
-      first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
+            first_fc_grad->ops_[2]->Output(
+                "X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
 
 TEST(Backward, net_shared_weight) {
@@ -266,13 +278,14 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
+  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"},
-      {{"temporary_index", std::vector<int>{1}}});
+  auto fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+                              {{"temporary_index", std::vector<int>{1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
@@ -320,9 +333,11 @@ TEST(Backward, op_part_of_output_are_not_need) {
 
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_TRUE(!backward->IsNetOp());
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 1UL);
-  auto &grad_mul = *backward;
+  auto &grad_mul = *net->ops_[0];
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
@@ -339,13 +354,10 @@ TEST(Backward, op_part_of_input_are_not_need) {
 
 TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   f::NetOp net;
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
-                                    {"mul_out1", "add_out1", "out1"}, {}));
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
-                                    {"mul_out2", "tmp_out2", "out2"}, {}));
-  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
-                                    {"mul_out3", "tmp_out3", "out3"}, {}));
-  net.CompleteAddOp();
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {}));
+  net.CompleteAddOp(false);
   auto backward = f::Backward(net, {"out2"});
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(backward.get());
......
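A note on the two recurring changes in this diff. If f::OperatorBase::GRAD_VAR_SUFFIX() returns a const char*, the removed expression "X" + f::OperatorBase::GRAD_VAR_SUFFIX() is ill-formed pointer arithmetic rather than string concatenation, which is what wrapping the literal as std::string("X") fixes; and the vector_equal helper replaces direct ASSERT_EQ over std::vector values with an explicit element-wise comparison. Below is a minimal standalone sketch of both patterns, assuming a const char* suffix helper and using plain assert instead of gtest; it is an editor's illustration, not code from the commit.

// Standalone sketch of the patterns above; not part of the commit.
// GRAD_VAR_SUFFIX is a stand-in for f::OperatorBase::GRAD_VAR_SUFFIX();
// its const char* return type is an assumption.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

const char *GRAD_VAR_SUFFIX() { return "@GRAD"; }

// Mirrors the vector_equal lambda in the diff: sizes must match, then every
// element is compared in order.
bool vector_equal(const std::vector<std::string> &l,
                  const std::vector<std::string> &r) {
  if (l.size() != r.size()) return false;
  for (std::size_t i = 0; i < l.size(); ++i) {
    if (l[i] != r[i]) return false;
  }
  return true;
}

int main() {
  // std::string("X") + const char* concatenates; the bare "X" + pointer form
  // would not compile, since adding two pointers is ill-formed.
  std::string grad_name = std::string("X") + GRAD_VAR_SUFFIX();
  assert(grad_name == "X@GRAD");

  std::vector<std::string> inputs{grad_name};
  assert(vector_equal(inputs, {"X@GRAD"}));
  assert(!vector_equal(inputs, {}));

  std::cout << "ok: " << grad_name << std::endl;
  return 0;
}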