backward.cc 7.8 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/framework/backward.h"
D
dongzhihong 已提交
16

D
dongzhihong 已提交
17
#include <list>
18
#include "paddle/framework/op_registry.h"
Y
Yan Chunwei 已提交
19
#include "paddle/operators/net_op.h"
Y
Yu Yang 已提交
20 21 22 23

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
24 25
// Visits every variable name held in `names`, a map whose mapped values are
// sequences of names (such as an operator's inputs_ / outputs_). `callback`
// is called on each name, in the map's iteration order; as soon as it
// returns true the walk stops. The callback receives a reference into
// `names`, so it may rewrite the name in place when `names` is non-const.
template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) {
  for (auto map_it = names.begin(); map_it != names.end(); ++map_it) {
    auto& var_list = map_it->second;
    for (auto var_it = var_list.begin(); var_it != var_list.end(); ++var_it) {
      if (callback(*var_it)) {
        return;  // the callback asked to stop early
      }
    }
  }
}

Y
Yan Chunwei 已提交
33
// Returns true when, for every variable name `n` listed anywhere in `names`,
// the string `n + suffix` is a member of `set`. Stops at the first miss and
// returns false. An empty `names` map trivially yields true.
static bool AllInSet(
    const std::map<std::string, std::vector<std::string>>& names,
    const std::string& suffix, const std::unordered_set<std::string>& set) {
  for (const auto& entry : names) {
    for (const auto& var_name : entry.second) {
      if (set.find(var_name + suffix) == set.end()) {
        return false;
      }
    }
  }
  return true;
}

Y
Yu Yang 已提交
45
static std::shared_ptr<OperatorBase> NOP() {
Y
Yan Chunwei 已提交
46
  auto net_op = std::make_shared<operators::NetOp>();
Y
Yu Yang 已提交
47
  net_op->type_ = "@NOP@";
Y
Yu Yang 已提交
48 49 50 51
  net_op->CompleteAddOp();
  return net_op;
}

Y
Yan Chunwei 已提交
52
//  Get backward operator from a forward operator, a recursive implementation.
Y
Yu Yang 已提交
53 54 55
//
//  no_grad_names the gradient variable names without gradient calculating.
//
56 57 58
//  uniq_id is a unique index used inside recursively calling
//  BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
//  pass `uniq_id` through recursive calling.
Y
Yu Yang 已提交
59
//
Y
Yan Chunwei 已提交
60 61
//  returns The backward operator. In a simple situation, it may be a simple
//  operator, in a complex situation, it maybe a NetOp.
Y
Yu Yang 已提交
62 63 64 65 66
//
//  See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
Y
Yan Chunwei 已提交
67

Y
Yu Yang 已提交
68
std::shared_ptr<OperatorBase> BackwardRecursive(
Y
Yu Yang 已提交
69 70
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
Y
Yu Yang 已提交
71 72
  //  If all input gradients of forwarding operator do not need to calculate,
  //  just return an NOP. Not return null ptr because NOP does not take
Y
Yan Chunwei 已提交
73 74 75
  //  much time for calculation, but it is useful for simplifying logic.
  if (AllInSet(forwardOp.inputs_ /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
Y
Yu Yang 已提交
76
    return NOP();
Y
Yu Yang 已提交
77 78
  }

79 80
  //  All output gradients of forwarding operator do not need to calculate.
  //  Then all input gradients cannot be computed at all, and we put them into
Y
Yu Yang 已提交
81
  //  `no_grad_names` set. Return an NOP.
Y
Yan Chunwei 已提交
82 83
  if (AllInSet(forwardOp.outputs_ /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
Y
Yu Yang 已提交
84 85 86 87 88
    ForEachVarName(forwardOp.inputs_,
                   [&no_grad_names](const std::string& name) -> bool {
                     no_grad_names.insert(GradVarName(name));
                     return false;
                   });
Y
Yu Yang 已提交
89
    return NOP();
Y
Yu Yang 已提交
90 91
  }

Y
Yu Yang 已提交
92
  // Returned gradient network
Y
Yan Chunwei 已提交
93
  auto net = std::make_shared<operators::NetOp>();
Y
Yu Yang 已提交
94 95

  if (forwardOp.IsNetOp()) {
Y
Yu Yang 已提交
96
    // Because forwardOp is a net op, it can static_cast.
Y
Yan Chunwei 已提交
97
    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
Y
Yu Yang 已提交
98

99
    // Map from output gradient variable name to operator's indices in
Y
Yan Chunwei 已提交
100
    // backward net's ops_. That operator generates that variable.
Y
Yu Yang 已提交
101 102 103
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

    size_t local_op_id = 0;
Y
Yan Chunwei 已提交
104
    // reversely travel forwardNet and collect all duplicate outputs.
Y
Yu Yang 已提交
105
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
Y
Yu Yang 已提交
106
         ++it, ++local_op_id) {
D
dongzhihong 已提交
107
      auto fwd = *it;
Y
Yu Yang 已提交
108
      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
D
dongzhihong 已提交
109
      net->AddOp(bwd);
Y
Yu Yang 已提交
110 111 112 113 114
      ForEachVarName(bwd->outputs_,
                     [&dup_output_ops, local_op_id](const std::string& out) {
                       dup_output_ops[out].emplace_back(local_op_id);
                       return false;
                     });
D
dongzhihong 已提交
115
    }
Y
Yu Yang 已提交
116
    // Get unique ID for this method.
D
dongzhihong 已提交
117
    auto uid = uniq_id++;
D
dongzhihong 已提交
118
    // TODO(dzh): more comment
Y
Yan Chunwei 已提交
119 120 121 122 123
    // multiple operators which have the same output (y for example) may
    // overwrite the same y variable when backward, special operations are token
    // to handle this case. For each duplicate output, rename it to an alias
    // (original name with a offset), append an `add` op for its operator,
    // and finally sum all the alias variable to the final output variable y.
Y
Yu Yang 已提交
124 125
    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
D
dongzhihong 已提交
126
    for (auto& dup_output_op : dup_output_ops) {
D
dongzhihong 已提交
127
      const std::string& name = dup_output_op.first;
D
dongzhihong 已提交
128
      auto& dup_op = dup_output_op.second;
Y
Yan Chunwei 已提交
129
      // no duplicate output
D
dongzhihong 已提交
130 131
      if (dup_op.size() == 1) continue;

Y
Yan Chunwei 已提交
132 133
      // process the duplicate outputs
      std::vector<std::string> dup_outputs;
D
dongzhihong 已提交
134
      for (size_t i = 0; i < dup_op.size(); ++i) {
Y
Yan Chunwei 已提交
135
        // rename each duplicate output to an alias
D
dongzhihong 已提交
136
        auto op_offset = dup_op[i];
D
dongzhihong 已提交
137 138 139
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
D
dongzhihong 已提交
140
      }
Y
Yan Chunwei 已提交
141
      // collect all the offset to append `add` op for each alias
Y
Yu Yang 已提交
142
      insert_position.push_back(
Y
Yu Yang 已提交
143 144
          {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
                                               {{"Out", {name}}}, {})});
D
dongzhihong 已提交
145
    }
Y
Yu Yang 已提交
146

Y
Yan Chunwei 已提交
147
    // make sure the inserted `add` ops follow the BFS order.
Y
Yu Yang 已提交
148
    insert_position.sort(
D
dongzhihong 已提交
149
        [](const Pos& l, const Pos& r) { return l.first > r.first; });
Y
Yu Yang 已提交
150 151

    for (auto& pos : insert_position) {
Y
Yu Yang 已提交
152
      net->InsertOp(pos.first + 1, pos.second);
D
dongzhihong 已提交
153
    }
Y
Yu Yang 已提交
154
  } else {
155
    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
Y
Yu Yang 已提交
156 157 158

    ForEachVarName(grad_op->inputs_, [&no_grad_names,
                                      &net](std::string& grad_input) {
159
      if (no_grad_names.count(grad_input)) {
Y
Yu Yang 已提交
160
        // +1 for \0
161
        std::string prefix = grad_input.substr(
Y
Yu Yang 已提交
162
            0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
163
        grad_input = prefix + kZeroVarSuffix;
Y
Yu Yang 已提交
164 165 166

        // If part of input gradient of that operator is not calculated, fill
        // zero variables to that input gradient.
Y
Yu Yang 已提交
167 168
        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}},
                                        {{"Dst", {grad_input}}}, {}));
169
      }
Y
Yu Yang 已提交
170 171 172 173 174 175 176 177 178 179
      return false;
    });

    ForEachVarName(grad_op->outputs_,
                   [&no_grad_names](std::string& grad_output) {
                     if (no_grad_names.count(grad_output)) {
                       grad_output = kEmptyVarName;
                     }
                     return false;
                   });
Y
Yu Yang 已提交
180 181 182 183

    if (net->ops_.empty()) {  // Current no aux op is added to network
      return grad_op;
    }
F
fengjiayi 已提交
184
    net->AddOp(grad_op);
Y
Yu Yang 已提交
185
  }
Y
Yu Yang 已提交
186
  net->type_ = "@GENERATED_BACKWARD@";
Y
Yu Yang 已提交
187
  net->CompleteAddOp();
Y
Yu Yang 已提交
188
  return net;
Y
Yu Yang 已提交
189
}  // namespace framework
Y
Yu Yang 已提交
190

Y
Yu Yang 已提交
191 192
// See header for comments
std::shared_ptr<OperatorBase> Backward(
Y
Yu Yang 已提交
193
    const OperatorBase& forwardOp,
Y
Yu Yang 已提交
194 195 196 197
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
  no_grad_names.reserve(no_grad_vars.size());

198
  no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
199

Y
Yu Yang 已提交
200
  for (auto& name : no_grad_vars) {
201
    no_grad_names.insert(name + kGradVarSuffix);
Y
Yu Yang 已提交
202
  }
Y
Yu Yang 已提交
203
  size_t uid = 0;
Y
Yu Yang 已提交
204
  return BackwardRecursive(forwardOp, no_grad_names, uid);
Y
Yu Yang 已提交
205
}
Y
Yi Wang 已提交
206

Y
Yu Yang 已提交
207 208
}  // namespace framework
}  // namespace paddle