提交 b2e1c48e 编写于 作者: D dongzhihong

Merge remote-tracking branch 'reyoung/feature/backward' into feature/backward

...@@ -53,11 +53,6 @@ static std::shared_ptr<OperatorBase> EmptyOp() { ...@@ -53,11 +53,6 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
static std::shared_ptr<OperatorBase> BackwardImpl( static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// struct OpIdentity {
// size_t local_op_id;
// size_t op_output_offset;
// };
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) { no_grad_names)) {
return EmptyOp(); return EmptyOp();
...@@ -87,7 +82,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -87,7 +82,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it) { ++it) {
auto fwd = *it; auto fwd = *it;
auto bwd = Backward(*fwd, no_grad_names); auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd); net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) { for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id); dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
...@@ -136,6 +131,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -136,6 +131,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
{grad_input}, {})); {grad_input}, {}));
} }
} }
for (std::string& grad_output : grad_op->outputs_) { for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) { if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME(); grad_output = OperatorBase::EMPTY_VAR_NAME();
......
...@@ -251,8 +251,9 @@ TEST(Backward, net_input_of_network_not_need_grad) { ...@@ -251,8 +251,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get()); auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), ASSERT_EQ(
first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
} }
TEST(Backward, net_shared_weight) { TEST(Backward, net_shared_weight) {
...@@ -265,13 +266,12 @@ TEST(Backward, net_shared_weight) { ...@@ -265,13 +266,12 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp()); ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get()); auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size()); ASSERT_EQ(3UL, bwd_net->ops_.size());
LOG(INFO) << bwd_net->DebugString();
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_); ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
} }
TEST(Backward, op_register_grad_not_for_network) { TEST(Backward, op_register_grad_not_for_network) {
auto fwd = auto fwd = f::OpRegistry::CreateOp(
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"},
{{"temporary_index", std::vector<int>{1}}}); {{"temporary_index", std::vector<int>{1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
} }
...@@ -320,11 +320,9 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -320,11 +320,9 @@ TEST(Backward, op_part_of_output_are_not_need) {
TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"}); auto backward = f::Backward(*fwd, {"a"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(!backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 1UL);
auto &grad_mul = *net->ops_[0]; auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL); ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
...@@ -341,10 +339,13 @@ TEST(Backward, op_part_of_input_are_not_need) { ...@@ -341,10 +339,13 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST(Backward, linear_net_intermediate_variable_has_no_grad) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
f::NetOp net; f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {})); {"mul_out1", "add_out1", "out1"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
net.CompleteAddOp(false); {"mul_out2", "tmp_out2", "out2"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
{"mul_out3", "tmp_out3", "out3"}, {}));
net.CompleteAddOp();
auto backward = f::Backward(net, {"out2"}); auto backward = f::Backward(net, {"out2"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get()); auto bwd_net = static_cast<f::NetOp *>(backward.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册