From d2c2f7855185ec7b683cba02d0e9ce9e42db1257 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 14 Aug 2017 17:47:16 +0800 Subject: [PATCH] change backward --- paddle/framework/backward.cc | 26 ++++++++++---------- paddle/framework/backward_test.cc | 40 +++++++++++++++---------------- paddle/framework/operator.h | 1 + 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 315bdde76d..a82dc4ef4b 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -22,7 +22,7 @@ namespace paddle { namespace framework { template -static void ForEachVarName(Map& names, T callback) { +static void ForEachVarName(const Map& names, T callback) { for (auto& name : names) { for (auto& n : name.second) { if (callback(n)) return; @@ -43,7 +43,7 @@ static bool AllInSet( static std::shared_ptr NOP() { auto net_op = std::make_shared(); - net_op->type_ = "@NOP@"; + net_op->SetType("@NOP@"); net_op->CompleteAddOp(); return net_op; } @@ -69,15 +69,15 @@ std::shared_ptr BackwardRecursive( // If all input gradients of forwarding operator do not need to calculate, // just return an NOP. Not return null ptr because NOP does not take // too much time for calculation, but it is useful for simplifying logic. - if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) { + if (AllInSet(forwardOp.Inputs(), kGradVarSuffix, no_grad_names)) { return NOP(); } // All output gradients of forwarding operator do not need to calculate. // Then all input gradients cannot be computed at all, and we put them into // `no_grad_names` set. Return an NOP. - if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) { - ForEachVarName(forwardOp.inputs_, + if (AllInSet(forwardOp.Outputs(), kGradVarSuffix, no_grad_names)) { + ForEachVarName(forwardOp.Inputs(), [&no_grad_names](const std::string& name) -> bool { no_grad_names.insert(GradVarName(name)); return false; @@ -103,7 +103,7 @@ std::shared_ptr BackwardRecursive( auto fwd = *it; auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); net->AddOp(bwd); - ForEachVarName(bwd->outputs_, + ForEachVarName(bwd->Outputs(), [&dup_output_ops, local_op_id](const std::string& out) { dup_output_ops[out].emplace_back(local_op_id); return false; @@ -144,13 +144,13 @@ std::shared_ptr BackwardRecursive( } else { std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); - ForEachVarName(grad_op->inputs_, [&no_grad_names, - &net](std::string& grad_input) { + ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, + grad_op](const std::string& grad_input) { if (no_grad_names.count(grad_input)) { // +1 for \0 std::string prefix = grad_input.substr( 0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); - grad_input = prefix + kZeroVarSuffix; + grad_op->Rename(grad_input, prefix + kZeroVarSuffix); // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. @@ -160,10 +160,10 @@ std::shared_ptr BackwardRecursive( return false; }); - ForEachVarName(grad_op->outputs_, - [&no_grad_names](std::string& grad_output) { + ForEachVarName(grad_op->Outputs(), + [&no_grad_names, &grad_op](const std::string& grad_output) { if (no_grad_names.count(grad_output)) { - grad_output = kEmptyVarName; + grad_op->Rename(grad_output, kEmptyVarName); } return false; }); @@ -173,7 +173,7 @@ std::shared_ptr BackwardRecursive( } net->AddOp(grad_op); } - net->type_ = "@GENERATED_BACKWARD@"; + net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); return net; } diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index e1e5379009..5874ef2f1f 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -173,8 +173,8 @@ TEST(Backward, simple_op_grad) { "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1UL, gop->inputs_.size()); - ASSERT_EQ("rowwise_add_grad", gop->type_); + ASSERT_EQ(1UL, gop->Inputs().size()); + ASSERT_EQ("rowwise_add_grad", gop->Type()); ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); } @@ -210,13 +210,13 @@ TEST(Backward, net_fc_backward_normal) { ASSERT_EQ(3UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_add = *net->ops_[1]; - ASSERT_EQ("rowwise_add_grad", d_add.type_); + ASSERT_EQ("rowwise_add_grad", d_add.Type()); f::OperatorBase &d_mul = *net->ops_[2]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_fc_backward_not_have_b) { @@ -236,10 +236,10 @@ TEST(Backward, net_fc_backward_not_have_b) { ASSERT_EQ(2UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_mul = *net->ops_[1]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_input_of_network_not_need_grad) { @@ -293,7 +293,7 @@ TEST(Backward, net_shared_weight) { ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); - ASSERT_EQ("add", bwd_net->ops_[2]->type_); + ASSERT_EQ("add", bwd_net->ops_[2]->Type()); } TEST(Backward, op_register_grad_not_for_network) { @@ -334,15 +334,15 @@ TEST(Backward, op_part_of_output_are_not_need) { ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; - ASSERT_EQ("fill_zeros_like", fill_zero.type_); + ASSERT_EQ("fill_zeros_like", fill_zero.Type()); ASSERT_EQ(1UL, fill_zero.Inputs("Src").size()); ASSERT_EQ("Z", fill_zero.Input("Src")); ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size()); ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst")); auto &d_many_out = *net->ops_[1]; - ASSERT_EQ("many_output_op_grad", d_many_out.type_); - ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("many_output_op_grad", d_many_out.Type()); + ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size()); // I/O/OG ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, d_many_out.Input(f::GradVarName("z"))); ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y"))); @@ -354,9 +354,9 @@ TEST(Backward, op_part_of_input_are_not_need) { {{"Out", {"out"}}}, {}); auto backward = f::Backward(*fwd, {"a"}); auto &grad_mul = *backward; - ASSERT_EQ(grad_mul.type_, "mul_grad"); - ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); - ASSERT_EQ(grad_mul.outputs_.size(), 2UL); + ASSERT_EQ(grad_mul.Type(), "mul_grad"); + ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.Outputs().size(), 2UL); ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName); ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b")); ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out")); @@ -394,18 +394,18 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { auto &grad_fc = *bwd_net->ops_[0]; const char *all = paddle::operators::NetOp::kAll; - EXPECT_EQ(grad_fc.inputs_[all].size(), + EXPECT_EQ(grad_fc.Inputs(all).size(), 2UL /* external input number */ + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ + 2U /* internal variable number*/); - EXPECT_EQ(grad_fc.outputs_[all].size(), + EXPECT_EQ(grad_fc.Outputs(all).size(), 2UL /* input number of mul*/ + 2UL /* input number of rowwise_add */ + 1UL /* input number of sigmod */); - EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 038e6fe7a2..acff4f0ca0 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -121,6 +121,7 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; const std::string& Type() const { return type_; } + void SetType(const std::string& type) { type_ = type; } const AttributeMap& Attrs() const { return attrs_; } protected: -- GitLab