提交 fa7cbfde 编写于 作者: D dongzhihong

"backward is NetOp"

上级 b1b13f8f
......@@ -48,9 +48,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
return net_op;
}
static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, int& uniq_id) {
static void DeDuplicate(NetOp* net, std::unordered_se)
static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
return EmptyOp();
......@@ -68,6 +70,38 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
auto* net = new NetOp();
if (forwardOp.IsNetOp()) {
std::unordered_map<std::string, int> dup_output;
std::unordered_map<std::string std::vector<int>> dup_output_ops;
const unsigned uniq_id_local = uniq_id;
unsigned op_id_offset = 0;
for (auto& fwd : forwardOp) {
auto bwd = Backward(fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd.outputs_; ++i) {
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
dup_output[bwd->inputs_[i]] = 1;
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
} else {
dup_output[bwd->inputs_[i]]++;
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
}
}
}
for (auto dup : dup_output) {
if (dup.second == 1) continue;
auto op_ids = dup_output_ops.at(dup.first);
for (auto& op_id : op_ids) {
auto& op_ptr = net->ops_[op_id];
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
if (op_ptr->inputs_[i] == dup.first) {
// unique the duplicate name
op_ptr->inputs_[i] += std::to_string(uniq_id++);
}
}
}
}
//! TODO(dzh)
} else {
//! TODO(fjy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册