Commit d2c2f785 authored by Q qiaolongfei

change backward

Parent 5d33ef61
@@ -22,7 +22,7 @@ namespace paddle {
 namespace framework {
 template <typename Map, typename T>
-static void ForEachVarName(Map& names, T callback) {
+static void ForEachVarName(const Map& names, T callback) {
   for (auto& name : names) {
     for (auto& n : name.second) {
       if (callback(n)) return;
@@ -43,7 +43,7 @@ static bool AllInSet(
 static std::shared_ptr<OperatorBase> NOP() {
   auto net_op = std::make_shared<operators::NetOp>();
-  net_op->type_ = "@NOP@";
+  net_op->SetType("@NOP@");
   net_op->CompleteAddOp();
   return net_op;
 }
@@ -69,15 +69,15 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   // If all input gradients of forwarding operator do not need to calculate,
   // just return an NOP. Not return null ptr because NOP does not take
   // too much time for calculation, but it is useful for simplifying logic.
-  if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) {
+  if (AllInSet(forwardOp.Inputs(), kGradVarSuffix, no_grad_names)) {
     return NOP();
   }
   // All output gradients of forwarding operator do not need to calculate.
   // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
-  if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
-    ForEachVarName(forwardOp.inputs_,
+  if (AllInSet(forwardOp.Outputs(), kGradVarSuffix, no_grad_names)) {
+    ForEachVarName(forwardOp.Inputs(),
                    [&no_grad_names](const std::string& name) -> bool {
                      no_grad_names.insert(GradVarName(name));
                      return false;
@@ -103,7 +103,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       auto fwd = *it;
       auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
-      ForEachVarName(bwd->outputs_,
+      ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
                        return false;
@@ -144,13 +144,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   } else {
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
-    ForEachVarName(grad_op->inputs_, [&no_grad_names,
-                                      &net](std::string& grad_input) {
+    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
+                                       grad_op](const std::string& grad_input) {
       if (no_grad_names.count(grad_input)) {
         // +1 for \0
         std::string prefix = grad_input.substr(
             0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
-        grad_input = prefix + kZeroVarSuffix;
+        grad_op->Rename(grad_input, prefix + kZeroVarSuffix);
         // If part of input gradient of that operator is not calculated, fill
         // zero variables to that input gradient.
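The `// +1 for \0` comment above packs in some array-size arithmetic: `sizeof` on a character array counts the trailing NUL, so the expression keeps exactly the characters before the gradient suffix. Below is a standalone sketch of the same computation; the suffix literals are assumptions for illustration, not values quoted from this diff.

#include <cassert>
#include <string>

// Assumed suffix values; the real constants are defined elsewhere in the framework.
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr char kZeroVarSuffix[] = "@ZERO";

int main() {
  std::string grad_input = "x@GRAD";
  // sizeof(kGradVarSuffix) == 6 (five characters plus '\0'), so the "+ 1"
  // compensates for the NUL: 6 - 6 + 1 == 1, i.e. keep only the prefix "x".
  std::string prefix = grad_input.substr(
      0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
  assert(prefix == "x");
  assert(prefix + kZeroVarSuffix == "x@ZERO");
  return 0;
}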
@@ -160,10 +160,10 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       return false;
     });
-    ForEachVarName(grad_op->outputs_,
-                   [&no_grad_names](std::string& grad_output) {
+    ForEachVarName(grad_op->Outputs(),
+                   [&no_grad_names, &grad_op](const std::string& grad_output) {
                      if (no_grad_names.count(grad_output)) {
-                       grad_output = kEmptyVarName;
+                       grad_op->Rename(grad_output, kEmptyVarName);
                      }
                      return false;
                    });
@@ -173,7 +173,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     }
     net->AddOp(grad_op);
   }
-  net->type_ = "@GENERATED_BACKWARD@";
+  net->SetType("@GENERATED_BACKWARD@");
   net->CompleteAddOp();
   return net;
 }
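The recurring pattern in the hunks above is that `inputs_` and `outputs_` are no longer touched directly: reads go through the const accessors `Inputs()`/`Outputs()`, and writes go through `grad_op->Rename(from, to)`, which is why the traversal callbacks now take `const std::string&`. Here is a minimal, self-contained sketch of that shape, with a toy class standing in for `OperatorBase`; everything in it is illustrative and not Paddle's real API.

#include <iostream>
#include <map>
#include <string>
#include <vector>

using VarNameMap = std::map<std::string, std::vector<std::string>>;

// Same traversal helper as in the diff: the map is taken by const reference,
// so the callback only ever sees const names.
template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) {
  for (auto& name : names) {
    for (auto& n : name.second) {
      if (callback(n)) return;
    }
  }
}

// Hypothetical stand-in for OperatorBase, reduced to name storage.
struct ToyOp {
  VarNameMap inputs;
  const VarNameMap& Inputs() const { return inputs; }
  // `from` is taken by value so it stays valid while the element is overwritten.
  void Rename(std::string from, const std::string& to) {
    for (auto& kv : inputs)
      for (auto& n : kv.second)
        if (n == from) n = to;
  }
};

int main() {
  ToyOp op;
  op.inputs["X"] = {"x@GRAD"};
  // The callback can no longer assign to `n`; it asks the op to rename instead.
  ForEachVarName(op.Inputs(), [&op](const std::string& n) -> bool {
    op.Rename(n, "x@ZERO");
    return false;
  });
  std::cout << op.inputs["X"][0] << "\n";  // prints: x@ZERO
  return 0;
}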
......
@@ -173,8 +173,8 @@ TEST(Backward, simple_op_grad) {
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->inputs_.size());
-  ASSERT_EQ("rowwise_add_grad", gop->type_);
+  ASSERT_EQ(1UL, gop->Inputs().size());
+  ASSERT_EQ("rowwise_add_grad", gop->Type());
   ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
   ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
 }
@@ -210,13 +210,13 @@ TEST(Backward, net_fc_backward_normal) {
   ASSERT_EQ(3UL, net->ops_.size());
   f::OperatorBase &d_sigmoid = *net->ops_[0];
-  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
   f::OperatorBase &d_add = *net->ops_[1];
-  ASSERT_EQ("rowwise_add_grad", d_add.type_);
+  ASSERT_EQ("rowwise_add_grad", d_add.Type());
   f::OperatorBase &d_mul = *net->ops_[2];
-  ASSERT_EQ("mul_grad", d_mul.type_);
+  ASSERT_EQ("mul_grad", d_mul.Type());
 }
 TEST(Backward, net_fc_backward_not_have_b) {
@@ -236,10 +236,10 @@ TEST(Backward, net_fc_backward_not_have_b) {
   ASSERT_EQ(2UL, net->ops_.size());
   f::OperatorBase &d_sigmoid = *net->ops_[0];
-  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
+  ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
   f::OperatorBase &d_mul = *net->ops_[1];
-  ASSERT_EQ("mul_grad", d_mul.type_);
+  ASSERT_EQ("mul_grad", d_mul.Type());
 }
 TEST(Backward, net_input_of_network_not_need_grad) {
@@ -293,7 +293,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add", bwd_net->ops_[2]->type_);
+  ASSERT_EQ("add", bwd_net->ops_[2]->Type());
 }
 TEST(Backward, op_register_grad_not_for_network) {
@@ -334,15 +334,15 @@ TEST(Backward, op_part_of_output_are_not_need) {
   ASSERT_EQ(net->ops_.size(), 2UL);
   auto &fill_zero = *net->ops_[0];
-  ASSERT_EQ("fill_zeros_like", fill_zero.type_);
+  ASSERT_EQ("fill_zeros_like", fill_zero.Type());
   ASSERT_EQ(1UL, fill_zero.Inputs("Src").size());
   ASSERT_EQ("Z", fill_zero.Input("Src"));
   ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size());
   ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst"));
   auto &d_many_out = *net->ops_[1];
-  ASSERT_EQ("many_output_op_grad", d_many_out.type_);
-  ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size());  // I/O/OG
+  ASSERT_EQ("many_output_op_grad", d_many_out.Type());
+  ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size());  // I/O/OG
   ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
             d_many_out.Input(f::GradVarName("z")));
   ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
@@ -354,9 +354,9 @@ TEST(Backward, op_part_of_input_are_not_need) {
                                 {{"Out", {"out"}}}, {});
   auto backward = f::Backward(*fwd, {"a"});
   auto &grad_mul = *backward;
-  ASSERT_EQ(grad_mul.type_, "mul_grad");
-  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
-  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
+  ASSERT_EQ(grad_mul.Type(), "mul_grad");
+  ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL);
+  ASSERT_EQ(grad_mul.Outputs().size(), 2UL);
   ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName);
   ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b"));
   ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
@@ -394,18 +394,18 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   auto &grad_fc = *bwd_net->ops_[0];
   const char *all = paddle::operators::NetOp::kAll;
-  EXPECT_EQ(grad_fc.inputs_[all].size(),
+  EXPECT_EQ(grad_fc.Inputs(all).size(),
             2UL /* external input number */
             + 1UL /* external output number*/
            + 1UL /* number of gradient of external output*/
            + 2U /* internal variable number*/);
-  EXPECT_EQ(grad_fc.outputs_[all].size(),
+  EXPECT_EQ(grad_fc.Outputs(all).size(),
             2UL /* input number of mul*/
             + 2UL /* input number of rowwise_add
                   */
            + 1UL /* input number of sigmod */);
-  EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
 }
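The assertions above lean on the `f::GradVarName(...)` naming convention: the gradient of a variable is the variable name plus the gradient suffix, which is why `gop->Output(f::GradVarName("X"))` is expected to equal `f::GradVarName("x")`. A hedged sketch of that convention follows; the suffix literal and the helper body are assumptions for illustration, not quoted from the framework.

#include <cassert>
#include <string>

// Assumed constant and helper; Paddle defines the real ones in its framework headers.
constexpr char kGradVarSuffix[] = "@GRAD";

inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

int main() {
  // Mirrors the test expectation: the gradient output bound to slot
  // GradVarName("X") is the gradient variable of the forward input "x".
  assert(GradVarName("x") == "x@GRAD");
  assert(GradVarName("Out") == "Out@GRAD");
  return 0;
}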
@@ -121,6 +121,7 @@ class OperatorBase {
   virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
   const std::string& Type() const { return type_; }
+  void SetType(const std::string& type) { type_ = type; }
   const AttributeMap& Attrs() const { return attrs_; }
  protected:
......
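With the new `SetType` mutator next to the existing `Type()` getter, callers such as `NOP()` and the generated backward net no longer write `type_` directly. A small usage sketch against a toy class (a placeholder, not the real `OperatorBase` or its registration flow):

#include <cassert>
#include <string>

// Toy class showing the accessor pair added above; the real OperatorBase has far more state.
class ToyOperator {
 public:
  const std::string& Type() const { return type_; }
  void SetType(const std::string& type) { type_ = type; }

 private:
  std::string type_;
};

int main() {
  ToyOperator net_op;
  net_op.SetType("@NOP@");                 // what NOP() now does instead of assigning type_
  assert(net_op.Type() == "@NOP@");
  net_op.SetType("@GENERATED_BACKWARD@");  // what BackwardRecursive sets on the whole net
  assert(net_op.Type() == "@GENERATED_BACKWARD@");
  return 0;
}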