backward.cc 13.6 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "paddle/framework/backward.h"
D
dongzhihong 已提交
16

F
fengjiayi 已提交
17
#include <deque>
D
dongzhihong 已提交
18
#include <list>
Y
Yu Yang 已提交
19 20
#include <memory>

F
fengjiayi 已提交
21
#include "paddle/framework/block_desc.h"
22
#include "paddle/framework/op_registry.h"
Y
Yan Chunwei 已提交
23
#include "paddle/operators/net_op.h"
Y
Yan Chunwei 已提交
24
#include "paddle/operators/recurrent_op.h"
Y
Yu Yang 已提交
25 26 27 28

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
29
// Calls `callback` on every variable name stored in `names`, a map whose
// mapped values are sequences of names (e.g. an operator's Inputs()/Outputs()).
// Names are visited in map order, then in sequence order within each entry.
// Iteration stops as soon as `callback` returns true.
template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) {
  for (const auto& slot : names) {
    for (const auto& var_name : slot.second) {
      if (callback(var_name)) {
        return;  // the callback asked to stop early
      }
    }
  }
}

Y
Yan Chunwei 已提交
38
// return whether all the names + suffixes in the set
Y
Yu Yang 已提交
39
static bool AllInSet(
Y
Yu Yang 已提交
40
    const std::map<std::string, std::vector<std::string>>& names,
Y
Yu Yang 已提交
41
    const std::string& suffix, const std::unordered_set<std::string>& set) {
42 43 44 45
  bool all_in_set = true;
  ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
    all_in_set = set.find(n + suffix) != set.end();
    return !all_in_set;
Y
Yu Yang 已提交
46
  });
47
  return all_in_set;
Y
Yu Yang 已提交
48 49
}

Y
Yu Yang 已提交
50 51
static std::unique_ptr<OperatorBase> NOP() {
  auto net_op = new operators::NetOp();
Q
qiaolongfei 已提交
52
  net_op->SetType("@NOP@");
Y
Yu Yang 已提交
53
  net_op->CompleteAddOp();
Y
Yu Yang 已提交
54
  return std::unique_ptr<OperatorBase>(net_op);
Y
Yu Yang 已提交
55 56
}

Y
Yan Chunwei 已提交
57
//  Get backward operator from a forward operator, a recursive implementation.
Y
Yu Yang 已提交
58 59 60
//
//  no_grad_names the gradient variable names without gradient calculating.
//
61 62 63
//  uniq_id is a unique index used inside recursively calling
//  BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
//  pass `uniq_id` through recursive calling.
Y
Yu Yang 已提交
64
//
Y
Yan Chunwei 已提交
65 66
//  returns The backward operator. In a simple situation, it may be a simple
//  operator, in a complex situation, it maybe a NetOp.
Y
Yu Yang 已提交
67 68
//
//  See Backward.h for details
Y
Yu Yang 已提交
69
static std::unique_ptr<OperatorBase> BackwardRecursive(
Y
Yu Yang 已提交
70 71
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
Y
Yu Yang 已提交
72 73
  //  If all input gradients of forwarding operator do not need to calculate,
  //  just return an NOP. Not return null ptr because NOP does not take
Q
typo  
qiaolongfei 已提交
74
  //  too much time for calculation, but it is useful for simplifying logic.
75
  if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/,
Y
Yan Chunwei 已提交
76
               no_grad_names /*set*/)) {
Y
Yu Yang 已提交
77
    return NOP();
Y
Yu Yang 已提交
78 79
  }

80 81
  //  All output gradients of forwarding operator do not need to calculate.
  //  Then all input gradients cannot be computed at all, and we put them into
Y
Yu Yang 已提交
82
  //  `no_grad_names` set. Return an NOP.
Q
qiaolongfei 已提交
83
  if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
Y
Yan Chunwei 已提交
84
               no_grad_names /*set*/)) {
Q
qiaolongfei 已提交
85
    ForEachVarName(forwardOp.Inputs(),
Y
Yu Yang 已提交
86 87 88 89
                   [&no_grad_names](const std::string& name) -> bool {
                     no_grad_names.insert(GradVarName(name));
                     return false;
                   });
Y
Yu Yang 已提交
90
    return NOP();
Y
Yu Yang 已提交
91 92
  }

Y
Yu Yang 已提交
93
  // Returned gradient network
Y
Yu Yang 已提交
94
  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
Y
Yu Yang 已提交
95 96

  if (forwardOp.IsNetOp()) {
Y
Yu Yang 已提交
97
    // Because forwardOp is a net op, it can static_cast.
Y
Yan Chunwei 已提交
98
    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
Y
Yu Yang 已提交
99

100
    // Map from output gradient variable name to operator's indices in
Y
Yan Chunwei 已提交
101
    // backward net's ops_. That operator generates that variable.
Y
Yu Yang 已提交
102 103 104
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

    size_t local_op_id = 0;
Y
Yan Chunwei 已提交
105
    // reversely travel forwardNet and collect all duplicate outputs.
Y
Yu Yang 已提交
106
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
Y
Yu Yang 已提交
107
         ++it, ++local_op_id) {
Y
Yu Yang 已提交
108
      auto& fwd = *it;
Y
Yu Yang 已提交
109
      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
Q
qiaolongfei 已提交
110
      ForEachVarName(bwd->Outputs(),
Y
Yu Yang 已提交
111 112 113 114
                     [&dup_output_ops, local_op_id](const std::string& out) {
                       dup_output_ops[out].emplace_back(local_op_id);
                       return false;
                     });
Y
Yu Yang 已提交
115
      net->AppendOp(std::move(bwd));
D
dongzhihong 已提交
116
    }
Y
Yu Yang 已提交
117
    // Get unique ID for this method.
D
dongzhihong 已提交
118
    auto uid = uniq_id++;
D
dongzhihong 已提交
119
    // TODO(dzh): more comment
Y
Yan Chunwei 已提交
120 121 122 123 124
    // multiple operators which have the same output (y for example) may
    // overwrite the same y variable when backward, special operations are token
    // to handle this case. For each duplicate output, rename it to an alias
    // (original name with a offset), append an `add` op for its operator,
    // and finally sum all the alias variable to the final output variable y.
Y
Yu Yang 已提交
125
    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
Y
Yu Yang 已提交
126
    std::list<Pos> insert_position;
D
dongzhihong 已提交
127
    for (auto& dup_output_op : dup_output_ops) {
D
dongzhihong 已提交
128
      const std::string& name = dup_output_op.first;
Q
qijun 已提交
129 130 131
      // duplicate @Empty@ don't need to be added
      if (name == kEmptyVarName) continue;

D
dongzhihong 已提交
132
      auto& dup_op = dup_output_op.second;
Y
Yan Chunwei 已提交
133
      // no duplicate output
D
dongzhihong 已提交
134 135
      if (dup_op.size() == 1) continue;

Y
Yan Chunwei 已提交
136 137
      // process the duplicate outputs
      std::vector<std::string> dup_outputs;
D
dongzhihong 已提交
138
      for (size_t i = 0; i < dup_op.size(); ++i) {
Y
Yan Chunwei 已提交
139
        // rename each duplicate output to an alias
D
dongzhihong 已提交
140
        auto op_offset = dup_op[i];
D
dongzhihong 已提交
141 142 143
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
D
dongzhihong 已提交
144
      }
Y
Yan Chunwei 已提交
145
      // collect all the offset to append `add` op for each alias
D
dzhwinter 已提交
146 147 148
      //
      // one variable is shared between multiple operators.
      // insert add operator one by one, then add it to output
D
dongzhihong 已提交
149 150 151 152 153 154 155 156 157 158 159 160
      for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
           ++output_idx) {
        auto insert_add_x = dup_outputs[output_idx];
        auto insert_add_y = dup_outputs[output_idx];
        auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
        // first add op inserted
        if (output_idx == dup_outputs.size() - 2) {
          insert_add_out = name;
        }
        if (output_idx != 0) {
          insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
        }
D
dzhwinter 已提交
161 162 163
        insert_position.push_back(
            {dup_op.back(),
             OpRegistry::CreateOp(
D
dongzhihong 已提交
164
                 "sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
D
dongzhihong 已提交
165
                 {{"Out", {insert_add_out}}}, {})});
D
dzhwinter 已提交
166
      }
D
dongzhihong 已提交
167
    }
Y
Yu Yang 已提交
168

Y
Yan Chunwei 已提交
169
    // make sure the inserted `add` ops follow the BFS order.
Y
Yu Yang 已提交
170
    insert_position.sort(
D
dongzhihong 已提交
171
        [](const Pos& l, const Pos& r) { return l.first > r.first; });
Y
Yu Yang 已提交
172 173

    for (auto& pos : insert_position) {
Y
Yu Yang 已提交
174
      net->InsertOp(pos.first + 1, std::move(pos.second));
D
dongzhihong 已提交
175
    }
Y
Yu Yang 已提交
176
  } else {
Y
Yu Yang 已提交
177
    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
Y
Yu Yang 已提交
178

Y
Yu Yang 已提交
179 180
    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                          const std::string& grad_input) {
181
      if (no_grad_names.count(grad_input)) {
Y
Yu Yang 已提交
182
        // +1 for \0
183
        std::string prefix = grad_input.substr(
Y
Yu Yang 已提交
184
            0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
Q
qiaolongfei 已提交
185
        grad_op->Rename(grad_input, prefix + kZeroVarSuffix);
Y
Yu Yang 已提交
186 187 188

        // If part of input gradient of that operator is not calculated, fill
        // zero variables to that input gradient.
D
dangqingqing 已提交
189 190
        net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}},
                                           {{"Y", {grad_input}}}, {}));
191
      }
Y
Yu Yang 已提交
192 193 194
      return false;
    });

Q
qiaolongfei 已提交
195 196
    ForEachVarName(grad_op->Outputs(),
                   [&no_grad_names, &grad_op](const std::string& grad_output) {
Y
Yu Yang 已提交
197
                     if (no_grad_names.count(grad_output)) {
Q
qiaolongfei 已提交
198
                       grad_op->Rename(grad_output, kEmptyVarName);
Y
Yu Yang 已提交
199 200 201
                     }
                     return false;
                   });
Y
Yu Yang 已提交
202

Y
Yan Chunwei 已提交
203
    // process recurrent gradient op as a special operator.
204
    if (forwardOp.Type() == "recurrent") {
Y
Yan Chunwei 已提交
205 206 207 208 209 210 211 212 213 214
      // NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
      // this will result in infinite loop.
      const auto& rnnop =
          *static_cast<const operators::RecurrentOp*>(&forwardOp);
      auto rnn_grad_op =
          static_cast<operators::RecurrentGradientOp*>(grad_op.get());
      const auto& stepnet_op =
          *static_cast<const OperatorBase*>(&rnnop.stepnet());
      // create stepnet's gradient op
      rnn_grad_op->set_stepnet(
Y
Yu Yang 已提交
215
          BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
Y
Yan Chunwei 已提交
216 217
    }

Y
Yu Yang 已提交
218 219 220
    if (net->ops_.empty()) {  // Current no aux op is added to network
      return grad_op;
    }
Y
Yu Yang 已提交
221
    net->AppendOp(std::move(grad_op));
Y
Yu Yang 已提交
222
  }
Q
qiaolongfei 已提交
223
  net->SetType("@GENERATED_BACKWARD@");
Y
Yu Yang 已提交
224
  net->CompleteAddOp();
Y
Yu Yang 已提交
225 226 227
  return std::unique_ptr<OperatorBase>(
      static_cast<OperatorBase*>(net.release()));
}
Y
Yu Yang 已提交
228

Y
Yu Yang 已提交
229
// See header for comments
Y
Yu Yang 已提交
230
std::unique_ptr<OperatorBase> Backward(
Y
Yu Yang 已提交
231
    const OperatorBase& forwardOp,
Y
Yu Yang 已提交
232 233
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
Q
qijun 已提交
234
  no_grad_names.reserve(no_grad_vars.size() + 1);
Y
Yu Yang 已提交
235

236
  no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
237

Y
Yu Yang 已提交
238
  for (auto& name : no_grad_vars) {
239
    no_grad_names.insert(name + kGradVarSuffix);
Y
Yu Yang 已提交
240
  }
Y
Yu Yang 已提交
241
  size_t uid = 0;
Y
Yu Yang 已提交
242
  return BackwardRecursive(forwardOp, no_grad_names, uid);
Y
Yu Yang 已提交
243
}
Y
Yi Wang 已提交
244

F
fengjiayi 已提交
245 246 247 248 249 250 251 252 253 254 255 256
// ====================================  //

// Returns true when GradVarName(name) is in `set` for every entry of `names`
// (vacuously true for an empty `names`); stops at the first missing one.
static bool AllGradInSet(const std::vector<std::string>& names,
                         const std::unordered_set<std::string>& set) {
  bool all_present = true;
  for (size_t i = 0; all_present && i < names.size(); ++i) {
    all_present = set.find(GradVarName(names[i])) != set.end();
  }
  return all_present;
}

F
Update  
fengjiayi 已提交
257 258
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
    const std::unique_ptr<OpDescBind>& op_desc,
F
fengjiayi 已提交
259
    std::unordered_set<std::string>& no_grad_vars) {
F
Update  
fengjiayi 已提交
260
  std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
F
fengjiayi 已提交
261
  // All input gradients of forwarding operator do not need to calculat.
F
fengjiayi 已提交
262
  if (AllGradInSet(op_desc->InputArgumentNames(), no_grad_vars)) {
F
fengjiayi 已提交
263 264 265
    return grad_op_descs;  // empty vector
  }
  // All output gradients of forwarding operator do not need to calculate.
F
fengjiayi 已提交
266 267
  const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
  if (AllGradInSet(outputs, no_grad_vars)) {
F
fengjiayi 已提交
268 269 270 271 272 273
    for (const std::string& name : outputs) {
      no_grad_vars.insert(GradVarName(name));
    }
    return grad_op_descs;  // empty vector
  }

F
fengjiayi 已提交
274
  grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
F
fengjiayi 已提交
275

F
Update  
fengjiayi 已提交
276 277 278
  std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
  for (auto& desc : grad_op_descs) {
    for (const std::string& in_name : desc->InputArgumentNames()) {
F
fengjiayi 已提交
279 280 281 282
      if (no_grad_vars.count(in_name)) {
        std::string prefix = in_name.substr(
            0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
        std::string new_name = prefix + kZeroVarSuffix;
F
Update  
fengjiayi 已提交
283
        desc->Rename(in_name, new_name);
F
fengjiayi 已提交
284 285 286
        std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
        pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
F
fengjiayi 已提交
287 288
      }
    }
F
fengjiayi 已提交
289
    for (const std::string& out_name : desc->OutputArgumentNames()) {
F
fengjiayi 已提交
290
      if (no_grad_vars.count(out_name)) {
F
Update  
fengjiayi 已提交
291
        desc->Rename(out_name, kEmptyVarName);
F
fengjiayi 已提交
292 293 294
      }
    }
  }
F
fengjiayi 已提交
295 296 297
  for (auto& p : pending_fill_zeros_ops) {
    grad_op_descs.push_back(std::move(p));
  }
F
fengjiayi 已提交
298

F
fengjiayi 已提交
299
  // TODO(fengjiayi): RNN op
F
fengjiayi 已提交
300 301 302
  return grad_op_descs;
}

F
fengjiayi 已提交
303 304
// Appends the backward ops of every op in `block_desc` to the end of
// block_desc.ops_, visiting the forward ops in reverse order.
//
// no_grad_vars  gradient variable names (forward name + kGradVarSuffix) that
//               must not be calculated; MakeGradOpDescs may insert more.
//
// When several backward ops write the same variable, each write is renamed to
// "<name>@RENAME@<i>" and a `sum` op, inserted right after the last writer,
// adds the aliases back into the original variable.
void AppendBackwardOpDescs(BlockDescBind& block_desc,
                           std::unordered_set<std::string>& no_grad_vars) {
  std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;

  // Generate the backward ops, recording for every output variable the
  // indices (into backward_descs) of the ops that write it.
  std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
  std::vector<std::unique_ptr<OpDescBind>> backward_descs;
  for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
    std::vector<std::unique_ptr<OpDescBind>> op_grads =
        MakeGradOpDescs(*it, no_grad_vars);
    for (auto& desc : op_grads) {
      const size_t grad_desc_idx = backward_descs.size();
      for (const std::string& out_name : desc->OutputArgumentNames()) {
        std::vector<size_t>& writers = dup_out_ops[out_name];
        // Record an op once per variable even if it lists the variable in
        // several output slots; a second rename of the same op would be a
        // no-op and leave the `sum` reading an alias nobody produces.
        if (writers.empty() || writers.back() != grad_desc_idx) {
          writers.push_back(grad_desc_idx);
        }
      }
      backward_descs.push_back(std::move(desc));
    }
  }

  // Check whether some variables are written more than once.
  using PendingSum = std::pair<size_t, std::unique_ptr<OpDescBind>>;
  std::list<PendingSum> pending_sum_ops;
  for (const auto& dup : dup_out_ops) {
    const std::string& out_name = dup.first;
    const std::vector<size_t>& dup_op = dup.second;
    if (out_name == kEmptyVarName || dup_op.size() <= 1) continue;

    std::vector<std::string> sum_op_inputs;
    sum_op_inputs.reserve(dup_op.size());
    for (size_t i = 0; i < dup_op.size(); ++i) {
      std::string new_name = out_name + "@RENAME@" + std::to_string(i);
      backward_descs[dup_op[i]]->Rename(out_name, new_name);
      sum_op_inputs.push_back(std::move(new_name));
    }
    std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
        "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
    pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
  }

  // Insert from the highest position down, so that an insertion never shifts
  // a position that is still pending.
  pending_sum_ops.sort([](const PendingSum& a, const PendingSum& b) {
    return a.first > b.first;
  });
  for (auto& p : pending_sum_ops) {
    backward_descs.insert(backward_descs.begin() + p.first + 1,
                          std::move(p.second));
  }

  // Append backward_descs to BlockDescBind::ops_
  for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
    block_op_descs.push_back(std::move(ptr));
  }
}

Y
Yu Yang 已提交
355 356
}  // namespace framework
}  // namespace paddle