提交 d6e03682 编写于 作者: Y Yu Yang

Add comment in backward.cc

上级 29d50ad9
...@@ -50,50 +50,72 @@ static std::shared_ptr<OperatorBase> EmptyOp() { ...@@ -50,50 +50,72 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
return net_op; return net_op;
} }
/**
* @brief Backward an operator, implementation
* @param forwardOp the forward operator
* @param no_grad_names variable names not calculate for gradient. Like X@GRAD
* is not needed.
* @param uniq_id a unique index used inside BackwardImpl, it will be shared
* through recursive invoke.
* @return The backward operator. For simple situation, it is a simple operator.
* For complex situation, it is a NetOp.
*
* See Backward.h for details
*/
static std::shared_ptr<OperatorBase> BackwardImpl( static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
/**
* If all input gradients of forwarding operator do not need to calculate,
* just return an EmptyOp. Not return null ptr because EmptyOp does not take
* too much time for calculation, but it is useful for simplifying logic.
*/
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) { no_grad_names)) {
return EmptyOp(); return EmptyOp();
} }
/**
* All output gradients of forwarding operator do not need to calculate. Then
* all input gradients cannot be computed at all, and we put them into
* `no_grad_names` set. Return an EmptyOp.
*/
if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(), if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) { no_grad_names)) {
for (auto& name : forwardOp.inputs_) { for (auto& name : forwardOp.inputs_) {
// Mark all input is not need /// Mark all input is not need
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
} }
return EmptyOp(); return EmptyOp();
} }
//! Returned gradient network
auto net = std::make_shared<NetOp>(); auto net = std::make_shared<NetOp>();
if (forwardOp.IsNetOp()) { if (forwardOp.IsNetOp()) {
//! TODO(dzh) /// Because forwardOp is a net op, it can static_cast.
std::unordered_map<std::string /*var name*/,
std::vector<size_t> /*op offset*/>
dup_output_ops;
size_t local_op_id = 0;
// Because it is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp); auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// travesal subnet/op //! Map from output gradient variable name to operator's indices in backward
//! net. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
size_t local_op_id = 0;
/// reversely travel forwardNet
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it) { ++it, ++local_op_id) {
auto fwd = *it; auto fwd = *it;
auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id); auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd); net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) { for (auto& out : bwd->outputs_) {
dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id); dup_output_ops[out].emplace_back(local_op_id);
} }
local_op_id++;
} }
// unique the duplicate name /// Get unique ID for this method.
auto uid = uniq_id++; auto uid = uniq_id++;
// TODO(dzh): more comment // TODO(dzh): more comment
typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos; using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
std::list<Pos> insert_postion; std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) { for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first; const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second; auto& dup_op = dup_output_op.second;
...@@ -106,16 +128,18 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -106,16 +128,18 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
std::to_string(i)); std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back()); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
insert_postion.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(),
OpRegistry::CreateOp( OpRegistry::CreateOp(
"add", {dup_outputs}, {name}, "add", {dup_outputs}, {name},
{{"input_format", {{"input_format",
std::vector<int>{0, (int)dup_outputs.size()}}})}); std::vector<int>{0, (int)dup_outputs.size()}}})});
} }
insert_postion.sort(
insert_position.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; }); [](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_postion) {
for (auto& pos : insert_position) {
net->InsertOp(pos.first, pos.second); net->InsertOp(pos.first, pos.second);
} }
...@@ -148,6 +172,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl( ...@@ -148,6 +172,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return net; return net;
} }
//! See header for comments
extern std::shared_ptr<OperatorBase> Backward( extern std::shared_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册