backward.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/framework/backward.h"
#include "paddle/operators/net_op.h"

#include <deque>
#include <list>
#include <memory>
#include <unordered_set>

#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

namespace paddle {
namespace framework {

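// Create the gradient operator for `op`. If the op's GradOpMaker produces
// more than one gradient OpDesc, the resulting operators are wrapped in a
// single NetOp so that exactly one operator is always returned.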
static inline std::unique_ptr<OperatorBase> CreateGradOp(
    const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  OpDescBind op_desc;
  op_desc.SetInputMap(op.Inputs());
  op_desc.SetOutputMap(op.Outputs());
  op_desc.SetType(op.Type());
  op_desc.SetAttrMap(op.Attrs());
  auto& info = OpInfoMap::Instance().Get(op.Type());
  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var, {});
  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
  grad_ops.reserve(grad_descs.size());
  std::transform(grad_descs.begin(), grad_descs.end(),
                 std::back_inserter(grad_ops),
                 [](const std::unique_ptr<OpDescBind>& grad_desc) {
                   return OpRegistry::CreateOp(*grad_desc);
                 });
  PADDLE_ENFORCE(!grad_ops.empty());
  if (grad_ops.size() == 1) {
    return std::move(grad_ops[0]);
  } else {
    auto net_op = new operators::NetOp();
    for (auto& grad_op : grad_ops) {
      net_op->AppendOp(std::move(grad_op));
    }
    net_op->CompleteAddOp();
    return std::unique_ptr<OperatorBase>(net_op);
  }
}

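// Invoke `callback` on every variable name in `names` (a map from parameter
// name to a list of variable names); stop early if the callback returns true.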
template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) {
  for (auto& name : names) {
    for (auto& n : name.second) {
      if (callback(n)) return;
    }
  }
}

// Returns whether all of the names, each with `suffix` appended, are in `set`.
static bool AllInSet(
    const std::map<std::string, std::vector<std::string>>& names,
    const std::string& suffix, const std::unordered_set<std::string>& set) {
  bool all_in_set = true;
  ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
    all_in_set = set.find(n + suffix) != set.end();
    return !all_in_set;
  });
  return all_in_set;
}

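// Create an empty NetOp, used as a placeholder when there is nothing to
// compute for the backward pass.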
static std::unique_ptr<OperatorBase> NOP() {
  auto net_op = new operators::NetOp();
  net_op->SetType("@NOP@");
  net_op->CompleteAddOp();
  return std::unique_ptr<OperatorBase>(net_op);
}

//  Get the backward operator from a forward operator; a recursive
//  implementation.
//
//  no_grad_names contains the names of the variables whose gradients do not
//  need to be computed.
//
//  uniq_id is a unique index used inside the recursive calls of
//  BackwardRecursive. Use `uid = uniq_id++;` to get the unique index, and
//  pass `uniq_id` through the recursive calls.
//
//  Returns the backward operator. In a simple case it is a single operator;
//  in a complex case it may be a NetOp.
//
//  See Backward.h for details
static std::unique_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names,
    std::unordered_map<std::string, std::string>* grad_to_var,
    size_t& uniq_id) {
  //  If none of the input gradients of the forward operator needs to be
  //  computed, just return a NOP. We do not return a null pointer because a
  //  NOP takes little time to compute and it keeps the caller's logic simple.
  if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
    return NOP();
  }

  //  If none of the output gradients of the forward operator needs to be
  //  computed, then none of the input gradients can be computed at all; put
  //  them into the `no_grad_names` set and return a NOP.
  if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
    ForEachVarName(forwardOp.Inputs(),
                   [&no_grad_names](const std::string& name) -> bool {
                     no_grad_names.insert(GradVarName(name));
                     return false;
                   });
    return NOP();
  }

  // The gradient network to be returned
  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());

  if (forwardOp.IsNetOp()) {
    // forwardOp is a NetOp, so the static_cast below is safe.
    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);

    // Map from an output gradient variable name to the indices (in the
    // backward net's ops_) of the operators that generate that variable.
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

    size_t local_op_id = 0;
    // Traverse forwardNet in reverse order and collect all duplicate outputs.
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
         ++it, ++local_op_id) {
      auto& fwd = *it;
      auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
      ForEachVarName(bwd->Outputs(),
                     [&dup_output_ops, local_op_id](const std::string& out) {
                       dup_output_ops[out].emplace_back(local_op_id);
                       return false;
                     });
      net->AppendOp(std::move(bwd));
    }
    // Get a unique ID for this method.
    auto uid = uniq_id++;
    // TODO(dzh): more comment
    // Multiple operators that share the same output (y, for example) would
    // overwrite the same y variable during the backward pass, so special
    // handling is needed: rename each duplicated output to an alias (the
    // original name plus an offset), and finally insert a `sum` op that adds
    // all the aliases back into the original output variable y.
    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
    for (auto& dup_output_op : dup_output_ops) {
      const std::string& name = dup_output_op.first;
      // Duplicated @Empty@ outputs do not need to be added.
      if (name == kEmptyVarName) continue;

      auto& dup_op = dup_output_op.second;
      // Skip outputs that are not duplicated.
      if (dup_op.size() == 1) continue;

      // Process the duplicated outputs.
      std::vector<std::string> dup_outputs;
      for (size_t i = 0; i < dup_op.size(); ++i) {
        // Rename each duplicated output to an alias.
        auto op_offset = dup_op[i];
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
      }
      // Collect the position of each alias and insert a `sum` operator that
      // adds all the aliases into the original output.
      insert_position.push_back(
          {dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}},
                                               {{"Out", {name}}}, {})});
    }

    // Make sure the inserted `sum` ops follow the BFS order.
    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });

    for (auto& pos : insert_position) {
      net->InsertOp(pos.first + 1, std::move(pos.second));
    }
  } else {
    std::unique_ptr<OperatorBase> grad_op(
        CreateGradOp(forwardOp, no_grad_names, grad_to_var));

    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                          const std::string& grad_input) {
      if (no_grad_names.count(grad_input)) {
        // +1 for \0
        std::string prefix = grad_input.substr(
            0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
        grad_op->Rename(grad_input, prefix + kZeroVarSuffix);

        // If part of the operator's input gradients is not computed, feed a
        // zero-filled variable to that input gradient.
        net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}},
                                           {{"Y", {grad_input}}}, {}));
      }
      return false;
    });

    ForEachVarName(grad_op->Outputs(),
                   [&no_grad_names, &grad_op](const std::string& grad_output) {
                     if (no_grad_names.count(grad_output)) {
                       grad_op->Rename(grad_output, kEmptyVarName);
                     }
                     return false;
                   });

    if (net->ops_.empty()) {  // No auxiliary op has been added to the network.
      return grad_op;
    }
    net->AppendOp(std::move(grad_op));
  }
  net->SetType("@GENERATED_BACKWARD@");
  net->CompleteAddOp();
  return std::unique_ptr<OperatorBase>(
      static_cast<OperatorBase*>(net.release()));
}

// See header for comments
std::unique_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
  no_grad_names.reserve(no_grad_vars.size() + 1);

  no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);

  for (auto& name : no_grad_vars) {
    no_grad_names.insert(name + kGradVarSuffix);
  }
  size_t uid = 0;
  std::unordered_map<std::string, std::string> grad_to_var;
  return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
}
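// A minimal usage sketch (the forward operator and variable name below are
// hypothetical, for illustration only):
//
//   std::unordered_set<std::string> no_grad_vars = {"bias"};
//   std::unique_ptr<OperatorBase> bwd_op = Backward(*forward_op, no_grad_vars);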

// ====================================  //
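// The functions below build the backward pass at the program-description
// level (OpDescBind / BlockDescBind / ProgramDescBind), as opposed to the
// OperatorBase-based construction above.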

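// Returns true when the gradient variable of every name in `names` is
// already in `set`.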
static bool AllGradInSet(const std::vector<std::string>& names,
                         const std::unordered_set<std::string>& set) {
  for (const std::string& name : names) {
    if (!set.count(GradVarName(name))) {
      return false;
    }
  }
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    sout << "All input {";
    for (auto& name : names) {
      sout << name << ",";
    }
    sout << "} is in {";
    for (auto& name : set) {
      sout << name << ",";
    }
    sout << "}";
    VLOG(10) << sout.str();
  }
  return true;
}

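// Strip the "@GRAD" suffix from a gradient variable name to recover the
// forward variable name; returns an empty string if there is no such suffix.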
static std::string FwdName(const std::string& grad_name) {
  auto pos = grad_name.find("@GRAD");
  if (pos == std::string::npos) {
    return "";
  } else {
    return grad_name.substr(0, pos);
  }
}

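// For every backward op in `block_desc` starting at `grad_op_start_index`,
// create the output gradient variables that do not exist yet, record the
// parameter-to-gradient info in `grad_var_record`, and infer the type, data
// type, and shape of the newly created variables.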
static void CreateGradVarInBlock(
    size_t grad_op_start_index,
    const std::unordered_map<std::string, std::string>& param_name_map,
    BlockDescBind* block_desc,
    std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
  auto ops = block_desc->AllOps();
  for (size_t op_index = grad_op_start_index; op_index < ops.size();
       ++op_index) {
    std::unordered_set<std::string> new_vars;
    ForEachVarName(ops[op_index]->Outputs(),
                   [&](const std::string& grad_var_name) {
                     if (block_desc->HasVar(grad_var_name)) {
                       return false;
                     }
                     auto var = block_desc->Var(grad_var_name);
                     new_vars.insert(var->Name());
                     auto it = param_name_map.find(grad_var_name);
                     if (it == param_name_map.end()) {
                       return false;
                     }
                     auto param_var_name = it->second;
                     auto& grad_record = (*grad_var_record)[param_var_name];
                     grad_record.name_ = grad_var_name;
                     grad_record.block_idx_ = block_desc->ID();
                     grad_record.op_idx_ = static_cast<int>(op_index);
                     return false; /* not break */
                   });
    ops[op_index]->InferVarType(block_desc);
    for (auto& arg : ops[op_index]->OutputArgumentNames()) {
      if (new_vars.find(arg) == new_vars.end()) {
        continue;
      }
      auto pname = FwdName(arg);
      auto* param = block_desc->FindVarRecursive(pname);
      auto* grad = block_desc->FindVar(arg);
      if (param == nullptr) {
        grad->SetDataType(DataType::FP32);
      } else {
        grad->SetDataType(param->GetDataType());
      }
    }
    ops[op_index]->InferShape(*block_desc);
  }
}

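// Build the gradient OpDescs for a single forward OpDesc. Gradient inputs
// that are marked as not computed (in `no_grad_vars`) are replaced with
// zero-filled variables by prepending a `fill_zeros_like` op. Returns an
// empty vector when no gradient needs to be computed for this op.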
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
    const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var,
    const std::vector<BlockDescBind*>& grad_block =
        std::vector<BlockDescBind*>()) {
  std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
  // If none of the input gradients of the forward operator needs to be
  // computed, there is nothing to do.
  const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
  if (AllGradInSet(inputs, *no_grad_vars)) {
    return grad_op_descs;  // empty vector
  }
  // If none of the output gradients of the forward operator needs to be
  // computed, then none of its input gradients can be computed either.
  const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
  if (AllGradInSet(outputs, *no_grad_vars)) {
    for (const std::string& name : inputs) {
      no_grad_vars->insert(GradVarName(name));
    }
    return grad_op_descs;  // empty vector
  }

  grad_op_descs =
      OpInfoMap::Instance()
          .Get(op_desc->Type())
          .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);

  std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
  for (auto& desc : grad_op_descs) {
    for (const std::string& in_name : desc->InputArgumentNames()) {
      if (no_grad_vars->count(in_name)) {
        std::string prefix = in_name.substr(
            0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
        std::string new_name = prefix + kZeroVarSuffix;
        desc->Rename(in_name, new_name);
        std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
        pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
      }
    }
  }

  for (auto& p : pending_fill_zeros_ops) {
    grad_op_descs.insert(grad_op_descs.begin(), std::move(p));
  }
  return grad_op_descs;
}

static BlockDescBind* CreateStepBlock(
    ProgramDescBind& program_desc,
    std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var,
    int step_block_idx);

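// Build the backward OpDescs for every op in block `block_idx`, traversing
// the forward ops in reverse order. Outputs that are written by more than
// one backward op are renamed and merged with a `sum` op.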
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
    ProgramDescBind& program_desc, int block_idx,
    std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  VLOG(5) << "MakeBlockBackward";
  BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
  std::vector<OpDescBind*> op_descs = cur_block->AllOps();
  std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
  size_t grad_desc_idx = 0;
  std::vector<std::unique_ptr<OpDescBind>> backward_descs;

  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
    VLOG(5) << "Making backward " << (*it)->Type() << " op";
    std::vector<std::unique_ptr<OpDescBind>> op_grads;

    if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
      int step_block_idx = (*it)->GetBlockAttr("step_block");
      BlockDescBind* backward_block = CreateStepBlock(
          program_desc, no_grad_vars, grad_to_var, step_block_idx);
      op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
    } else if ((*it)->Type() == "conditional_block") {
      BlockDescBind* backward_block =
          CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
                          (*it)->GetBlockAttr("block"));
      op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
    } else {
      op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
    }

    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      sout << "Made ";
      for (auto& op_grad : op_grads) {
        sout << op_grad->Type() << " ";
      }
      VLOG(10) << sout.str();
    }

    for (const auto& desc : op_grads) {
      for (const std::string& out_name : desc->OutputArgumentNames()) {
        if (out_name.find("@GRAD") == std::string::npos) {
          // Not every output of a backward operator is a gradient. Only
          // gradients need to be summed; skip variables that are not
          // gradients.
          continue;
        }
        dup_out_ops[out_name].emplace_back(grad_desc_idx);
      }
      ++grad_desc_idx;
    }
    std::transform(
        op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
        [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
  }

  VLOG(5) << "Appending Sums";
F
Update  
fengjiayi 已提交
434
  // Check whether some variables are written more than once
F
Update  
fengjiayi 已提交
435
  std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
F
Update  
fengjiayi 已提交
436 437 438 439 440
  for (const auto& dup : dup_out_ops) {
    const std::string& out_name = dup.first;
    const std::vector<size_t> dup_op = dup.second;
    if (out_name != kEmptyVarName && dup_op.size() > 1) {
      std::vector<std::string> sum_op_inputs;
Y
Yang Yang(Tony) 已提交
441
      std::string next_g_name = out_name;
F
Update  
fengjiayi 已提交
442
      for (size_t i = 0; i < dup_op.size(); ++i) {
Y
Yang Yang(Tony) 已提交
443 444
        VLOG(10) << backward_descs[dup_op[i]]->Type() << " has " << out_name
                 << " duplicated";
F
Update  
fengjiayi 已提交
445
        std::string new_name = out_name + "@RENAME@" + std::to_string(i);
Y
Yang Yang(Tony) 已提交
446 447
        backward_descs[dup_op[i]]->RenameOutput(out_name, new_name);
        backward_descs[dup_op[i]]->RenameInput(out_name, next_g_name);
F
Update  
fengjiayi 已提交
448
        sum_op_inputs.emplace_back(new_name);
Y
Yang Yang(Tony) 已提交
449
        next_g_name = sum_op_inputs.back();
F
Update  
fengjiayi 已提交
450
      }
F
fengjiayi 已提交
451 452 453
      std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
          "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
      pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
F
Update  
fengjiayi 已提交
454 455
    }
  }
Y
Yang Yang(Tony) 已提交
456

  pending_sum_ops.sort(
      [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
         const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
        return a.first > b.first;
      });
  for (auto& p : pending_sum_ops) {
    backward_descs.insert(backward_descs.begin() + p.first + 1,
                          std::move(p.second));
  }

  VLOG(5) << "MakeBlockBackward Finished";

  return backward_descs;
}

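// Create the backward block for a step/sub-block (used by the recurrent,
// while, and conditional_block ops) and append it to the program.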
static BlockDescBind* CreateStepBlock(
    ProgramDescBind& program_desc,
    std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var,
    int step_block_idx) {
  auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
                                                   no_grad_vars, grad_to_var);
  BlockDescBind* backward_block =
      program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
  for (auto& ptr : backward_block_op_descs) {
    backward_block->AppendAllocatedOp(std::move(ptr));
  }
  return backward_block;
}

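// Append the backward pass for `target` (the loss variable) to
// `program_desc`, and return a map from each parameter name to the
// GradVarInfo (gradient variable name, block index, op index) of its
// gradient.
//
// A minimal usage sketch (the program and loss descriptor below are
// hypothetical, for illustration only):
//
//   ProgramDescBind program;
//   // ... build the forward pass and obtain the loss VarDescBind `loss` ...
//   auto param_grads = AppendBackward(program, loss, /*no_grad_vars=*/{});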
ParamGradInfoMap AppendBackward(
    ProgramDescBind& program_desc, const VarDescBind& target,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_var_names;
  no_grad_var_names.reserve(no_grad_vars.size() + 1);
  no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
  for (auto& name : no_grad_vars) {
    no_grad_var_names.insert(GradVarName(name));
  }

  const int root_block_idx = 0;
  auto root_block = program_desc.MutableBlock(root_block_idx);

  std::string fill_one_op_out = GradVarName(target.Name());
  bool is_scalar = target.Shape() == std::vector<int64_t>{1};
  PADDLE_ENFORCE(is_scalar, "target should be scalar");
  VLOG(3) << "backward from loss=" << target.Name()
          << " data_type=" << target.GetDataType();
  std::unique_ptr<OpDescBind> fill_one_op(
      new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
                     {{"shape", std::vector<int>{1}},
                      {"value", static_cast<float>(1.0)},
                      {"dtype", target.GetDataType()}}));
  // infer var type of fill_one_op
  fill_one_op->InferVarType(root_block);

  root_block->AppendAllocatedOp(std::move(fill_one_op));
  size_t forward_op_num = root_block->OpSize();
  size_t forward_block_num = program_desc.Size();

  // Insert backward operators
  std::unordered_map<std::string, std::string> grad_to_var;
  auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
                                             &no_grad_var_names, &grad_to_var);

  for (auto& ptr : backward_op_descs) {
    root_block->AppendAllocatedOp(std::move(ptr));
  }
  // Create Variable

  // Create target gradient variable
  std::unordered_map<std::string, GradVarInfo> retv;

  auto var = root_block->Var(fill_one_op_out);
  var->SetDataType(target.GetDataType());
  var->SetShape(target.Shape());
  auto& target_grad = retv[target.Name()];
  target_grad.name_ = fill_one_op_out;
  target_grad.block_idx_ = root_block_idx;
  target_grad.op_idx_ = static_cast<int>(forward_op_num);

  // create grad_var for all blocks in this program
  CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
  for (size_t block_index = forward_block_num;
       block_index < program_desc.Size(); ++block_index) {
    CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
                         &retv);
  }
  return retv;
}

}  // namespace framework
}  // namespace paddle