Commit dfb4ea76 authored by qingqing01

make unit test of backward_test pass.

Parent 88104905
@@ -25,7 +25,7 @@ template <typename Map, typename T>
 static void ForEachVarName(Map& names, T callback) {
   for (auto& name : names) {
     for (auto& n : name.second) {
-      if (callback(n)) break;
+      if (callback(n)) return;
     }
   }
 }
@@ -33,12 +33,12 @@ static void ForEachVarName(Map& names, T callback) {
 static bool AllInSet(
     const std::unordered_map<std::string, std::vector<std::string>>& names,
     const std::string& suffix, const std::unordered_set<std::string>& set) {
-  bool ret_val = true;
-  ForEachVarName(names, [&ret_val, &set, &suffix](const std::string& n) {
-    ret_val = set.find(n + suffix) == set.end();
-    return !ret_val;
+  bool all_in_set = true;
+  ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
+    all_in_set = set.find(n + suffix) != set.end();
+    return !all_in_set;
   });
-  return ret_val;
+  return all_in_set;
 }
 static std::shared_ptr<OperatorBase> NOP() {
...
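The two fixes above work together: with "return" instead of "break", a callback that asks to stop now exits the whole scan rather than only the inner loop, and the renamed flag in AllInSet now records that every name plus suffix was found instead of the opposite. A minimal standalone sketch of the resulting semantics, using plain STL containers and "@GRAD" purely as an example suffix (not the framework's own constant):

#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Stand-in for the fixed AllInSet: true only if every listed variable name,
// with the suffix appended, appears in the set.
static bool AllInSetSketch(
    const std::unordered_map<std::string, std::vector<std::string>>& names,
    const std::string& suffix, const std::unordered_set<std::string>& set) {
  for (const auto& entry : names) {
    for (const auto& n : entry.second) {
      if (set.find(n + suffix) == set.end()) {
        return false;  // one miss decides the answer, like the early return
      }
    }
  }
  return true;
}

int main() {
  std::unordered_set<std::string> grads = {"x@GRAD", "b@GRAD"};
  assert(AllInSetSketch({{"X", {"x"}}, {"b", {"b"}}}, "@GRAD", grads));
  assert(!AllInSetSketch({{"X", {"x"}}, {"W", {"w"}}}, "@GRAD", grads));
  return 0;
}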
@@ -82,11 +82,11 @@ class FcOp : public operators::NetOp {
     AddOp(OpRegistry::CreateOp("mul",
                                {{"X", {Input("X")}}, {"Y", {Input("W")}}},
                                {{"Out", {Output("mul_result")}}}, {}));
-    auto b_name = Input("b");
+    auto input_b = Inputs("b");
     std::string before_act = "mul_result";
-    if (b_name != kEmptyVarName) {
+    if (input_b.size() != 0) {
       AddOp(OpRegistry::CreateOp(
-          "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {b_name}}},
+          "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
           {{"Out", {Output("add_result")}}}, {}));
       before_act = "add_result";
     } else {
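The FcOp change follows the calling convention used throughout this commit: each input or output slot ("X", "W", "b", ...) maps to a list of variable names, so a missing bias is an empty list rather than the kEmptyVarName placeholder, and Inputs("b") replaces the single-string Input("b"). A small self-contained sketch of that convention, with plain STL types standing in for the real OperatorBase API:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  using VarMap = std::unordered_map<std::string, std::vector<std::string>>;
  VarMap with_bias = {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}};
  VarMap without_bias = {{"X", {"x"}}, {"W", {"w"}}, {"b", {}}};

  // The updated FcOp only appends the rowwise_add step when the "b" slot is
  // non-empty, mirroring Inputs("b").size() != 0 in the diff above.
  assert(with_bias.at("b").size() != 0);
  assert(without_bias.at("b").size() == 0);
  return 0;
}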
@@ -166,209 +166,242 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
 REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
-// TEST(Backward, simple_op_grad) {
-//   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
-//   ASSERT_NE(fwd, nullptr);
-//   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-//   ASSERT_EQ(4UL, gop->inputs_.size());
-//   ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
-//   ASSERT_EQ("rowwise_add_grad", gop->type_);
-//   ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]);
-//   ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]);
-//
-//   ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X")));
-//}
-//
-// TEST(Backward, simple_op_not_need_grad) {
-//   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
-//   ASSERT_NE(fwd, nullptr);
-//   auto gop = f::Backward(*fwd, {"X"});
-//   ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
-//                       f::GradVarName("X")),
-//             gop->outputs_.end());
-//
-//   auto no_input_gop = f::Backward(*fwd, {"X", "b"});
-//   ASSERT_NE(no_input_gop, nullptr);
-//   ASSERT_TRUE(no_input_gop->IsNetOp());
-//   ASSERT_EQ(0UL,
-//   std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
-//}
-//
-// TEST(Backward, net_fc_backward_normal) {
-//   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
-//       "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
-//   ASSERT_NE(fwd, nullptr);
-//   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
-//   ASSERT_TRUE(gop->IsNetOp());
-//   auto net = static_cast<ops::NetOp *>(gop.get());
-//
-//   ASSERT_NO_THROW(net->DebugString());
-//
-//   ASSERT_EQ(3UL, net->ops_.size());
-//
-//   f::OperatorBase &d_sigmoid = *net->ops_[0];
-//   ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
-//
-//   f::OperatorBase &d_add = *net->ops_[1];
-//   ASSERT_EQ("rowwise_add_grad", d_add.type_);
-//
-//   f::OperatorBase &d_mul = *net->ops_[2];
-//   ASSERT_EQ("mul_grad", d_mul.type_);
-//}
-//
-// TEST(Backward, net_fc_backward_not_have_b) {
-//   std::shared_ptr<f::OperatorBase> fwd =
-//       f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
-//                               {"mul_result", "add_result", "tmp"}, {});
-//   ASSERT_NE(fwd, nullptr);
-//   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
-//   ASSERT_TRUE(gop->IsNetOp());
-//   auto net = static_cast<ops::NetOp *>(gop.get());
-//
-//   ASSERT_NO_THROW(net->DebugString());
-//
-//   ASSERT_EQ(2UL, net->ops_.size());
-//
-//   f::OperatorBase &d_sigmoid = *net->ops_[0];
-//   ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
-//
-//   f::OperatorBase &d_mul = *net->ops_[1];
-//   ASSERT_EQ("mul_grad", d_mul.type_);
-//}
-//
-// TEST(Backward, net_input_of_network_not_need_grad) {
-//   ops::NetOp net;
-//   net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
-//                                     {"mul_tmp_0", "add_tmp_0", "hidden0"},
-//                                     {}));
-//   net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
-//                                     {"mul_tmp_1", "add_tmp_1", "hidden1"},
-//                                     {}));
-//   net.CompleteAddOp();
-//   auto bwd = Backward(net, {"X"});  // X@GRAD is not need.
-//   ASSERT_TRUE(bwd->IsNetOp());
-//   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
-//
-//   std::unordered_set<std::string> all_output =
-//       std::unordered_set<std::string>(
-//           bwd_net->outputs_.begin(), bwd_net->outputs_.end());
-//   all_output.erase(f::kEmptyVarName);
-//
-//   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
-//     ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end());
-//   }
-//
-//   // Not Generated X
-//   ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end());
-//
-//   ASSERT_EQ(2UL, bwd_net->ops_.size());
-//   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
-//   auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
-//   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-//   ASSERT_EQ(f::kEmptyVarName,
-//             first_fc_grad->ops_[2]->Output(f::GradVarName("A")));
-//}
-//
-// TEST(Backward, net_shared_weight) {
-//   ops::NetOp net;
-//   net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
-//   net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
-//   net.CompleteAddOp();
-//
-//   auto bwd = f::Backward(net, {});
-//   ASSERT_TRUE(bwd->IsNetOp());
-//   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
-//   ASSERT_EQ(3UL, bwd_net->ops_.size());
-//   ASSERT_EQ("add", bwd_net->ops_[2]->type_);
-//}
-//
-// TEST(Backward, op_register_grad_not_for_network) {
-//   auto fwd = f::OpRegistry::CreateOp(
-//       "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
-//       {{"temporary_index", std::vector<int>{0, 1}}});
-//
-//   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
-//}
-//
-// TEST(Backward, op_all_input_are_not_need) {
-//   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
-//   auto backward = f::Backward(*fwd, {"X", "b"});
-//   ASSERT_TRUE(backward->IsNetOp());
-//   auto net = static_cast<ops::NetOp *>(backward.get());
-//   ASSERT_TRUE(net->ops_.empty());
-//}
-//
-// TEST(Backward, op_all_output_are_not_need) {
-//   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
-//   auto backward = f::Backward(*fwd, {"Out"});
-//   ASSERT_TRUE(backward->IsNetOp());
-//   auto net = static_cast<ops::NetOp *>(backward.get());
-//   ASSERT_TRUE(net->ops_.empty());
-//}
-//
-// TEST(Backward, op_part_of_output_are_not_need) {
-//   auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
-//   auto backward = f::Backward(*fwd, {"Z"});
-//   ASSERT_TRUE(backward->IsNetOp());
-//   auto net = static_cast<ops::NetOp *>(backward.get());
-//   ASSERT_EQ(net->ops_.size(), 2UL);
-//
-//   auto &fill_zero = *net->ops_[0];
-//   ASSERT_EQ("fill_zeros_like", fill_zero.type_);
-//   ASSERT_EQ(1UL, fill_zero.inputs_.size());
-//   ASSERT_EQ("Z", fill_zero.inputs_[0]);
-//   ASSERT_EQ(1UL, fill_zero.outputs_.size());
-//   ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]);
-//
-//   auto &d_many_out = *net->ops_[1];
-//   ASSERT_EQ("many_output_op_grad", d_many_out.type_);
-//   ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
-//   ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
-//             d_many_out.Input(f::GradVarName("z")));
-//   ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
-//   ASSERT_EQ(f::GradVarName("X"), d_many_out.Output(f::GradVarName("x")));
-//}
-//
-// TEST(Backward, op_part_of_input_are_not_need) {
-//   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
-//   auto backward = f::Backward(*fwd, {"a"});
-//   auto &grad_mul = *backward;
-//   ASSERT_EQ(grad_mul.type_, "mul_grad");
-//   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
-//   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
-//   ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName);
-//   ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b"));
-//   ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
-//   ASSERT_EQ(grad_mul.Input("A"), "a");
-//   ASSERT_EQ(grad_mul.Input("B"), "b");
-//   ASSERT_EQ(grad_mul.Input("Out"), "out");
-//}
-//
-// TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
-//   ops::NetOp net;
-//   net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
-//                                     {"mul_out1", "add_out1", "out1"}, {}));
-//   net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
-//                                     {"mul_out2", "tmp_out2", "out2"}, {}));
-//   net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
-//                                     {"mul_out3", "tmp_out3", "out3"}, {}));
-//   net.CompleteAddOp();
-//   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
-//   ASSERT_TRUE(backward->IsNetOp());
-//   auto bwd_net = static_cast<ops::NetOp *>(backward.get());
-//   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
-//   auto &grad_fc = *bwd_net->ops_[0];
-//   EXPECT_EQ(grad_fc.inputs_.size(),
-//             3UL /* external input number */
-//                 + 1UL /* external output number*/
-//                 + 1UL /* number of gradient of external output*/
-//                 + 2U /* internal variable number*/);
-//   EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
-//                                          + 2UL /* input number of rowwise_add */
-//                                          + 1UL /* input number of sigmod */);
-//   EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
-//   EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
-//   EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
-//   EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
-//}
+TEST(Backward, simple_op_grad) {
+  auto fwd = f::OpRegistry::CreateOp(
+      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
+  ASSERT_NE(fwd, nullptr);
+  auto gop = f::OpRegistry::CreateGradOp(*fwd);
+  ASSERT_EQ(1UL, gop->inputs_.size());
+  ASSERT_EQ("rowwise_add_grad", gop->type_);
+  ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
+  ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
+}
+
+TEST(Backward, simple_op_not_need_grad) {
+  auto fwd = f::OpRegistry::CreateOp(
+      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
+  ASSERT_NE(fwd, nullptr);
+  auto gop = f::Backward(*fwd, {"x"});
+  ASSERT_EQ(gop->Output(f::GradVarName("X")), f::kEmptyVarName);
+
+  auto no_input_gop = f::Backward(*fwd, {"x", "b"});
+  ASSERT_NE(no_input_gop, nullptr);
+  ASSERT_TRUE(no_input_gop->IsNetOp());
+  ASSERT_EQ(0UL,
+            std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
+}
+
+TEST(Backward, net_fc_backward_normal) {
+  std::shared_ptr<f::OperatorBase> fwd =
+      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
+                              {{"mul_result", {"mul_res"}},
+                               {"add_result", {"add_re"}},
+                               {"Out", {"out"}}},
+                              {});
+  ASSERT_NE(fwd, nullptr);
+  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
+  ASSERT_TRUE(gop->IsNetOp());
+  auto net = static_cast<ops::NetOp *>(gop.get());
+
+  ASSERT_NO_THROW(net->DebugString());
+
+  ASSERT_EQ(3UL, net->ops_.size());
+
+  f::OperatorBase &d_sigmoid = *net->ops_[0];
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+
+  f::OperatorBase &d_add = *net->ops_[1];
+  ASSERT_EQ("rowwise_add_grad", d_add.type_);
+
+  f::OperatorBase &d_mul = *net->ops_[2];
+  ASSERT_EQ("mul_grad", d_mul.type_);
+}
+
+TEST(Backward, net_fc_backward_not_have_b) {
+  std::shared_ptr<f::OperatorBase> fwd =
+      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {}}},
+                              {{"mul_result", {"mul_res"}},
+                               {"add_result", {"add_res"}},
+                               {"Out", {"tmp"}}},
+                              {});
+  ASSERT_NE(fwd, nullptr);
+  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
+  ASSERT_TRUE(gop->IsNetOp());
+  auto net = static_cast<ops::NetOp *>(gop.get());
+
+  ASSERT_NO_THROW(net->DebugString());
+
+  ASSERT_EQ(2UL, net->ops_.size());
+
+  f::OperatorBase &d_sigmoid = *net->ops_[0];
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+
+  f::OperatorBase &d_mul = *net->ops_[1];
+  ASSERT_EQ("mul_grad", d_mul.type_);
+}
+
+TEST(Backward, net_input_of_network_not_need_grad) {
+  ops::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp(
+      "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
+      {{"mul_result", {"mul_tmp_0"}},
+       {"add_result", {"add_tmp_0"}},
+       {"Out", {"hidden0"}}},
+      {}));
+  net.AddOp(f::OpRegistry::CreateOp(
+      "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
+      {{"mul_result", {"mul_tmp_1"}},
+       {"add_result", {"add_tmp_1"}},
+       {"Out", {"hidden1"}}},
+      {}));
+  net.CompleteAddOp();
+  auto bwd = Backward(net, {"x"});  // x@GRAD is not need.
+  ASSERT_TRUE(bwd->IsNetOp());
+  auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
+
+  auto output_vars = bwd_net->OutputVars(true);
+  std::unordered_set<std::string> all_outputs =
+      std::unordered_set<std::string>(output_vars.begin(), output_vars.end());
+  all_outputs.erase(f::kEmptyVarName);
+
+  for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
+    ASSERT_NE(all_outputs.find(f::GradVarName(out)), all_outputs.end());
+  }
+
+  // Not Generated X
+  ASSERT_EQ(all_outputs.find(f::GradVarName("X")), all_outputs.end());
+
+  ASSERT_EQ(2UL, bwd_net->ops_.size());
+  ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
+  auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
+  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
+  ASSERT_EQ(f::kEmptyVarName,
+            first_fc_grad->ops_[2]->Output(f::GradVarName("X")));
+}
+
+TEST(Backward, net_shared_weight) {
+  ops::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
+                                    {{"Out", {"out"}}}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
+                                    {{"Out", {"FinalOut"}}}, {}));
+  net.CompleteAddOp();
+
+  auto bwd = f::Backward(net, {});
+  ASSERT_TRUE(bwd->IsNetOp());
+  auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
+  ASSERT_EQ(3UL, bwd_net->ops_.size());
+  ASSERT_EQ("add", bwd_net->ops_[2]->type_);
+}
+
+TEST(Backward, op_register_grad_not_for_network) {
+  auto fwd =
+      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
+                              {{"mul_result", {"mul_out"}},
+                               {"add_result", {"add_out"}},
+                               {"Out", {"out1"}}},
+                              {{"temporary_index", std::vector<int>{0, 1}}});
+
+  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
+}
+
+TEST(Backward, op_all_input_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp(
+      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
+  auto backward = f::Backward(*fwd, {"x", "b"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<ops::NetOp *>(backward.get());
+  ASSERT_TRUE(net->ops_.empty());
+}
+
+TEST(Backward, op_all_output_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp(
+      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
+  auto backward = f::Backward(*fwd, {"out"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<ops::NetOp *>(backward.get());
+  ASSERT_TRUE(net->ops_.empty());
+}
+
+TEST(Backward, op_part_of_output_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("many_output_op", {{"x", {"X"}}},
+                                     {{"y", {"Y"}}, {"z", {"Z"}}}, {});
+  auto backward = f::Backward(*fwd, {"Z"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<ops::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 2UL);
+
+  auto &fill_zero = *net->ops_[0];
+  ASSERT_EQ("fill_zeros_like", fill_zero.type_);
+  ASSERT_EQ(1UL, fill_zero.Inputs("Src").size());
+  ASSERT_EQ("Z", fill_zero.Input("Src"));
+  ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size());
+  ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst"));
+
+  auto &d_many_out = *net->ops_[1];
+  ASSERT_EQ("many_output_op_grad", d_many_out.type_);
+  ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
+  ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
+            d_many_out.Input(f::GradVarName("z")));
+  ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
+  ASSERT_EQ(f::GradVarName("X"), d_many_out.Output(f::GradVarName("x")));
+}
+
+TEST(Backward, op_part_of_input_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("mul", {{"X", {"a"}}, {"Y", {"b"}}},
+                                     {{"Out", {"out"}}}, {});
+  auto backward = f::Backward(*fwd, {"a"});
+  auto &grad_mul = *backward;
+  ASSERT_EQ(grad_mul.type_, "mul_grad");
+  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
+  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
+  ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName);
+  ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b"));
+  ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
+  ASSERT_EQ(grad_mul.Input("X"), "a");
+  ASSERT_EQ(grad_mul.Input("Y"), "b");
+  ASSERT_EQ(grad_mul.Input("Out"), "out");
+}
+
+TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
+  ops::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp(
+      "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
+      {{"mul_result", {"mul_out1"}},
+       {"add_result", {"add_out1"}},
+       {"Out", {"out1"}}},
+      {}));
+  net.AddOp(f::OpRegistry::CreateOp(
+      "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
+      {{"mul_result", {"mul_out2"}},
+       {"add_result", {"tmp_out2"}},
+       {"Out", {"out2"}}},
+      {}));
+  net.AddOp(f::OpRegistry::CreateOp(
+      "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
+      {{"mul_result", {"mul_out3"}},
+       {"add_result", {"tmp_out3"}},
+       {"Out", {"out3"}}},
+      {}));
+  net.CompleteAddOp();
+  auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto bwd_net = static_cast<ops::NetOp *>(backward.get());
+  ASSERT_EQ(bwd_net->ops_.size(), 3UL);
+  auto &grad_fc = *bwd_net->ops_[0];
+  EXPECT_EQ(grad_fc.inputs_["all"].size(),
+            2UL       /* external input number */
+                + 1UL /* external output number*/
+                + 1UL /* number of gradient of external output*/
+                + 2U  /* internal variable number*/);
+  EXPECT_EQ(grad_fc.outputs_["all"].size(),
+            2UL       /* input number of mul*/
+                + 2UL /* input number of rowwise_add */
+                + 1UL /* input number of sigmod */);
+  EXPECT_EQ(bwd_net->ops_[1]->inputs_["all"].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->outputs_["all"].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->inputs_["all"].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->outputs_["all"].size(), 0UL);
+}
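One assertion worth spelling out is net_shared_weight's expectation of a third "add" op: because "w" feeds both mul ops, each grad op produces its own partial gradient for it, and the backward net has to sum the two. A tiny numeric sketch of that reasoning, using plain doubles in place of tensors:

#include <cassert>

int main() {
  double x = 2.0, w = 3.0;
  double out = x * w;          // first mul
  double final_out = out * w;  // second mul reuses w
  (void)final_out;
  double d_final = 1.0;        // d(final_out)/d(final_out)

  // grad of the second mul: d/dw = out, d/dout = w
  double dw_from_second = d_final * out;
  double d_out = d_final * w;
  // grad of the first mul: d/dw = x
  double dw_from_first = d_out * x;

  // the "add" op the test looks for combines the two partial gradients
  double dw = dw_from_first + dw_from_second;
  assert(dw == d_out * x + out);  // 3*2 + 6 = 12
  return 0;
}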
@@ -43,7 +43,7 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
 const std::string& OperatorBase::Input(const std::string& name) const {
   auto it = inputs_.find(name);
-  PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_,
+  PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
                  name);
   PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
                     "Op %s input %s should contain only one variable", type_,
...
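For reference, the corrected message guards the lookup path the new tests exercise through Input(...): asking for a slot the op does not have, or for a slot bound to anything other than exactly one variable, should fail loudly. A rough sketch of that contract, with an assumed exception type and wording rather than the real PADDLE_ENFORCE machinery:

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

using VarMap = std::unordered_map<std::string, std::vector<std::string>>;

// Mirrors the intent of OperatorBase::Input: exactly one variable per slot.
const std::string& SingleInput(const VarMap& inputs, const std::string& name) {
  auto it = inputs.find(name);
  if (it == inputs.end()) {
    throw std::runtime_error("op does not have input " + name);
  }
  if (it->second.size() != 1UL) {
    throw std::runtime_error("input " + name +
                             " should contain only one variable");
  }
  return it->second[0];
}

int main() {
  VarMap inputs = {{"X", {"x"}}, {"b", {}}};
  std::cout << SingleInput(inputs, "X") << "\n";  // prints "x"
  try {
    SingleInput(inputs, "b");  // empty slot: not exactly one variable
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}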