BackwardRecursive(
if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
- net->AddOp(std::move(grad_op));
+ net->AppendOp(std::move(grad_op));
}
net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp();
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 74c001b06a9e7b2279abf998604f2acf1b1168e4..c8fa3fefe5632a36d9044b4bccfd3dbb7c64dbf6 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -21,18 +21,32 @@ grad_op_builder(fengjiayi)
given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
-1. bla bla bla (yuyang)
+1. Op
+
+ when the input forward network is an Op, return its gradient operator immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively and ensure them done. During the process, we need to collect the `OutputGradients` name.
+ when the input forward network is a NetOp, it needs to call the backward function of each sub NetOp/Operator recursively. During the process, we need to collect the `OutputGradients` names according to the forward NetOp.
+
+ **shared variable**. As illustrated in the pictures, two operators' `Output` `Gradient` will overwrite their shared input variable.
+
+
+ 
+
+ 1. Shared variable in two operators.
+
+
+
+ Sharing a variable between operators, or using the same input variable in multiple operators, leads to a duplicate gradient variable. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator in place of the overwriting links.
+
+
+ 
- We share variable in the same scope, as a result, duplicate operator `OutputGradients` will overwirte then duplicate variable.
+ 2. Replace the shared variable's gradient with an `Add` operator.
- ![./images/duplicate_op]()
+
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator instead.
-![./images/duplicate_op2]()
- Then collect the sub graph OutputGradients/InputGradients as the NetOp's and return it.
+ Then collect the sub-graph's `OutputGradients`/`InputGradients` as the NetOp's and return it.
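To make the rename-and-add step described above concrete, here is a minimal standalone C++ sketch of the idea: every gradient variable written by more than one backward op is given a unique temporary name, and a generic `add` op is appended to sum the copies back into the original gradient. The `Op` struct, the `@RENAME@` suffix, and the three passes are illustrative only, not the actual `Backward`/`BackwardRecursive` implementation.

```cpp
// Illustrative sketch of "rename duplicated gradient outputs, then merge with add".
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // Two backward ops both write the gradient of the shared forward input "w".
  std::vector<Op> grad_net = {
      {"mul_grad", {"x", "w", "out@GRAD"}, {"x@GRAD", "w@GRAD"}},
      {"mul_grad", {"out", "w", "FinalOut@GRAD"}, {"out@GRAD", "w@GRAD"}},
  };

  // Pass 1: count how many ops produce each gradient variable.
  std::map<std::string, int> producers;
  for (const auto& op : grad_net) {
    for (const auto& out : op.outputs) ++producers[out];
  }

  // Pass 2: rename every duplicated output so nothing gets overwritten.
  std::map<std::string, std::vector<std::string>> copies;
  std::map<std::string, int> next_id;
  for (auto& op : grad_net) {
    for (auto& out : op.outputs) {
      if (producers[out] > 1) {
        std::string renamed = out + "@RENAME@" + std::to_string(next_id[out]++);
        copies[out].push_back(renamed);
        out = renamed;
      }
    }
  }

  // Pass 3: merge the renamed copies back with a generic "add" operator.
  for (const auto& kv : copies) {
    grad_net.push_back({"add", kv.second, {kv.first}});
  }

  for (const auto& op : grad_net) {
    std::cout << op.type << " ->";
    for (const auto& out : op.outputs) std::cout << " " << out;
    std::cout << "\n";
  }
  return 0;
}
```

In the real PR this bookkeeping happens while `BackwardRecursive` unrolls the `NetOp` hierarchy, which is what the `net_shared_weight` test below exercises.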
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 2c5ec76dfeb8b8485951e4d94896b6758e0cb930..f100c4d05489ac3bd4ceb5f11ae871985f0e5d83 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -72,16 +72,16 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp {
public:
- FcOp(const std::string &type, const VarNameMap &inputs,
- const VarNameMap &outputs, const AttributeMap &attrs)
+ FcOp(const std::string &type, const VariableNameMap &inputs,
+ const VariableNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
- AddOp(OpRegistry::CreateOp("mul",
- {{"X", {Input("X")}}, {"Y", {Input("W")}}},
- {{"Out", {Output("mul_result")}}}, {}));
+ AppendOp(OpRegistry::CreateOp("mul",
+ {{"X", {Input("X")}}, {"Y", {Input("W")}}},
+ {{"Out", {Output("mul_result")}}}, {}));
auto input_b = Inputs("b");
std::string before_act = "mul_result";
if (input_b.size() != 0) {
- AddOp(OpRegistry::CreateOp(
+ AppendOp(OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
{{"Out", {Output("add_result")}}}, {}));
before_act = "add_result";
@@ -92,8 +92,8 @@ class FcOp : public operators::NetOp {
}
}
- AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
- {{"Out", {Output("Out")}}}, {}));
+ AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
+ {{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false);
}
};
@@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST(Backward, net_input_of_network_not_need_grad) {
ops::NetOp net;
- net.AddOp(f::OpRegistry::CreateOp(
+ net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_tmp_0"}},
{"add_result", {"add_tmp_0"}},
{"Out", {"hidden0"}}},
{}));
- net.AddOp(f::OpRegistry::CreateOp(
+ net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_tmp_1"}},
{"add_result", {"add_tmp_1"}},
@@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) {
TEST(Backward, net_shared_weight) {
ops::NetOp net;
- net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
- {{"Out", {"out"}}}, {}));
- net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
- {{"Out", {"FinalOut"}}}, {}));
+ net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
+ {{"Out", {"out"}}}, {}));
+ net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
+ {{"Out", {"FinalOut"}}}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
@@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ops::NetOp net;
- net.AddOp(f::OpRegistry::CreateOp(
+ net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"mul_result", {"mul_out1"}},
{"add_result", {"add_out1"}},
{"Out", {"out1"}}},
{}));
- net.AddOp(f::OpRegistry::CreateOp(
+ net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_out2"}},
{"add_result", {"tmp_out2"}},
{"Out", {"out2"}}},
{}));
- net.AddOp(f::OpRegistry::CreateOp(
+ net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
{{"mul_result", {"mul_out3"}},
{"add_result", {"tmp_out3"}},
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index 0a2a41f6b62658ac8633a6e384d099f8d6641f33..b02a599a800668b22e7fe39a10fa6dc132e305bd 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -20,13 +20,13 @@ namespace framework {
enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
- bool is_grad, OperatorBase::VarNameMap* vars) {
+ bool is_grad, VariableNameMap* vars) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars;
- const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_;
+ auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list =
- src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
+ src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
@@ -40,26 +40,18 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
- auto it = OpRegistry::op_info_map().find(op->Type());
- PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
- "'%s' has not been registered.", op->Type());
- PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
- op->Type());
- std::string grad_op_type = it->second.grad_op_type_;
- PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
- op->Type());
+ auto& info = OpInfoMap::Instance().Get(op->Type());
+ PADDLE_ENFORCE(info.HasGradientOp());
- OperatorBase::VarNameMap inputs;
- OperatorBase::VarNameMap outputs;
+ VariableNameMap inputs;
+ VariableNameMap outputs;
TransOpArg(op, OpArgType::IN, false, &inputs); // I
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
- it = OpRegistry::op_info_map().find(grad_op_type);
- PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
- "'%s' has not been registered.", grad_op_type);
- return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
+ auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
+ return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
}
} // namespace framework
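For reference, a hedged usage sketch of the refactored builder, assuming the matching headers from this PR are available. It only exercises calls that appear in the diff above (`OpRegistry::CreateOp`, `OpInfoMap::Instance().Get`, `HasGradientOp`, `BuildGradOp`); ownership of the returned raw pointer is elided.

```cpp
// Usage sketch only, not part of the PR.
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"

namespace f = paddle::framework;

void GradOpBuilderDemo() {
  // Build a forward "mul" op, then derive its gradient op.
  auto fwd = f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
                                     {{"Out", {"out"}}}, {});
  auto& info = f::OpInfoMap::Instance().Get(fwd->Type());
  if (info.HasGradientOp()) {
    // Grad op inputs: X, Y, Out and Out's gradient (the I, O, OG groups above);
    // grad op outputs: the gradients of X and Y (the IG group above).
    f::OperatorBase* grad = f::BuildGradOp(fwd.get());
    (void)grad;  // cleanup/ownership intentionally left out of this sketch
  }
}
```

Compared with the old code path, the two `op_info_map().find(...)` lookups and the hand-rolled "has not been registered" checks collapse into `OpInfoMap::Instance().Get(...)`, which presumably enforces registration itself.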
diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81ba29797c5f478e5d6a91236f3e8de1e6b43e49
--- /dev/null
+++ b/paddle/framework/op_info.cc
@@ -0,0 +1,29 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_info.h"
+
+namespace paddle {
+namespace framework {
+
+static OpInfoMap* g_op_info_map = nullptr;
+
+OpInfoMap& OpInfoMap::Instance() {
+ if (g_op_info_map == nullptr) {
+ g_op_info_map = new OpInfoMap();
+ }
+ return *g_op_info_map;
+}
+} // namespace framework
+} // namespace paddle
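A short note on the design: the map behind `OpInfoMap::Instance()` is allocated on first use through a raw pointer and never freed, a common way to keep a process-wide registry alive and immune to static destruction order issues. A generic sketch of that style next to the function-local-static alternative (neither struct below is PaddlePaddle code):

```cpp
// Generic illustration of the two common lazy-singleton styles.
struct Registry {
  // Same idea as op_info.cc: heap-allocate on first use and never delete, so
  // the instance cannot be destroyed before late static users during shutdown.
  static Registry& LeakyInstance() {
    static Registry* instance = new Registry();
    return *instance;
  }
  // Alternative: function-local static object, destroyed at program exit.
  static Registry& LocalStaticInstance() {
    static Registry instance;
    return instance;
  }
};
```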
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..94245c6c44aca962b0db890947a9dc5550ac0799
--- /dev/null
+++ b/paddle/framework/op_info.h
@@ -0,0 +1,101 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+#include <functional>
+#include <map>