提交 991c4d80 编写于 作者: Y Yan Chunwei 提交者: GitHub

add some doc to backward (#3474)

上级 4b4f7457
...@@ -30,6 +30,7 @@ static void ForEachVarName(Map& names, T callback) { ...@@ -30,6 +30,7 @@ static void ForEachVarName(Map& names, T callback) {
} }
} }
// return whether all the names + suffixes in the set
static bool AllInSet( static bool AllInSet(
const std::map<std::string, std::vector<std::string>>& names, const std::map<std::string, std::vector<std::string>>& names,
const std::string& suffix, const std::unordered_set<std::string>& set) { const std::string& suffix, const std::unordered_set<std::string>& set) {
...@@ -48,7 +49,7 @@ static std::shared_ptr<OperatorBase> NOP() { ...@@ -48,7 +49,7 @@ static std::shared_ptr<OperatorBase> NOP() {
return net_op; return net_op;
} }
// Get backward operator from a forward operator, recursively implementation. // Get backward operator from a forward operator, a recursive implementation.
// //
// no_grad_names the gradient variable names without gradient calculating. // no_grad_names the gradient variable names without gradient calculating.
// //
...@@ -56,27 +57,30 @@ static std::shared_ptr<OperatorBase> NOP() { ...@@ -56,27 +57,30 @@ static std::shared_ptr<OperatorBase> NOP() {
// BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and // BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
// pass `uniq_id` through recursive calling. // pass `uniq_id` through recursive calling.
// //
// returns The backward operator. For simple situation, it is a simple // returns The backward operator. In a simple situation, it may be a simple
// operator. For complex situation, it is a NetOp. // operator, in a complex situation, it maybe a NetOp.
// //
// See Backward.h for details // See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive( static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id); std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive( std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate, // If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take // just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic. // much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) { if (AllInSet(forwardOp.inputs_ /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) {
return NOP(); return NOP();
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into // Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP. // `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) { if (AllInSet(forwardOp.outputs_ /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) {
ForEachVarName(forwardOp.inputs_, ForEachVarName(forwardOp.inputs_,
[&no_grad_names](const std::string& name) -> bool { [&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name)); no_grad_names.insert(GradVarName(name));
...@@ -93,11 +97,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -93,11 +97,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp); auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
// Map from output gradient variable name to operator's indices in // Map from output gradient variable name to operator's indices in
// backward net. That operator generates that variable. // backward net's ops_. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops; std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
size_t local_op_id = 0; size_t local_op_id = 0;
// reversely travel forwardNet // reversely travel forwardNet and collect all duplicate outputs.
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) { ++it, ++local_op_id) {
auto fwd = *it; auto fwd = *it;
...@@ -112,25 +116,35 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -112,25 +116,35 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// Get unique ID for this method. // Get unique ID for this method.
auto uid = uniq_id++; auto uid = uniq_id++;
// TODO(dzh): more comment // TODO(dzh): more comment
// multiple operators which have the same output (y for example) may
// overwrite the same y variable when backward, special operations are token
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>; using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
std::list<Pos> insert_position; std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) { for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first; const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second; auto& dup_op = dup_output_op.second;
// no duplicate output
if (dup_op.size() == 1) continue; if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs;
// process the duplicate outputs
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) { for (size_t i = 0; i < dup_op.size(); ++i) {
// rename each duplicate output to an alias
auto op_offset = dup_op[i]; auto op_offset = dup_op[i];
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
std::to_string(i)); std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back()); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
// collect all the offset to append `add` op for each alias
insert_position.push_back( insert_position.push_back(
{dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}}, {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
{{"Out", {name}}}, {})}); {{"Out", {name}}}, {})});
} }
// make sure the inserted `add` ops follow the BFS order.
insert_position.sort( insert_position.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; }); [](const Pos& l, const Pos& r) { return l.first > r.first; });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册