backward.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/framework/backward.h"
#include "paddle/operators/net_op.h"

#include <deque>
#include <list>
#include <memory>
#include <unordered_set>

#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"

namespace paddle {
namespace framework {

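// Create the gradient operator(s) for a single forward operator by invoking
// its registered GradOpMaker. If the maker yields more than one gradient op,
// they are wrapped into a single NetOp.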
static inline std::unique_ptr<OperatorBase> CreateGradOp(
    const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  OpDescBind op_desc;
  op_desc.SetInputMap(op.Inputs());
  op_desc.SetOutputMap(op.Outputs());
  op_desc.SetType(op.Type());
  op_desc.SetAttrMap(op.Attrs());
  auto& info = OpInfoMap::Instance().Get(op.Type());
  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
  grad_ops.reserve(grad_descs.size());
  std::transform(grad_descs.begin(), grad_descs.end(),
                 std::back_inserter(grad_ops),
                 [](const std::unique_ptr<OpDescBind>& grad_desc) {
                   return OpRegistry::CreateOp(*grad_desc);
                 });
  PADDLE_ENFORCE(!grad_ops.empty());
  if (grad_ops.size() == 1) {
    return std::move(grad_ops[0]);
  } else {
    auto net_op = new operators::NetOp();
    for (auto& grad_op : grad_ops) {
      net_op->AppendOp(std::move(grad_op));
    }
    net_op->CompleteAddOp();
    return std::unique_ptr<OperatorBase>(net_op);
  }
}

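// Apply `callback` to every variable name in a {parameter -> variable names}
// map. Iteration stops as soon as the callback returns true.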
template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) {
  for (auto& name : names) {
    for (auto& n : name.second) {
      if (callback(n)) return;
    }
  }
}

// Return whether every name in `names`, with `suffix` appended, is in `set`.
static bool AllInSet(
    const std::map<std::string, std::vector<std::string>>& names,
    const std::string& suffix, const std::unordered_set<std::string>& set) {
  bool all_in_set = true;
  ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
    all_in_set = set.find(n + suffix) != set.end();
    return !all_in_set;
  });
  return all_in_set;
}

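// Create an empty NetOp, used when a forward operator needs no gradient
// computation at all.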
static std::unique_ptr<OperatorBase> NOP() {
  auto net_op = new operators::NetOp();
  net_op->SetType("@NOP@");
  net_op->CompleteAddOp();
  return std::unique_ptr<OperatorBase>(net_op);
}

//  Get the backward operator from a forward operator; a recursive
//  implementation.
//
//  no_grad_names is the set of gradient variable names that need no gradient
//  calculation.
//
//  uniq_id is a unique index used inside the recursive calls of
//  BackwardRecursive. Use `uid = uniq_id++;` to obtain a unique index, and
//  pass `uniq_id` through the recursive calls.
//
//  Returns the backward operator. In a simple situation it is a single
//  operator; in a complex situation it may be a NetOp.
//
//  See Backward.h for details
static std::unique_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names,
    std::unordered_map<std::string, std::string>* grad_to_var,
    size_t& uniq_id) {
  //  If none of the input gradients of the forward operator needs to be
  //  calculated, just return a NOP. We do not return a null pointer because
  //  a NOP takes hardly any time to compute, and it keeps the logic simple.
  if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
    return NOP();
  }

  //  If none of the output gradients of the forward operator needs to be
  //  calculated, then none of its input gradients can be computed at all;
  //  put them into the `no_grad_names` set and return a NOP.
  if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
               no_grad_names /*set*/)) {
    ForEachVarName(forwardOp.Inputs(),
                   [&no_grad_names](const std::string& name) -> bool {
                     no_grad_names.insert(GradVarName(name));
                     return false;
                   });
    return NOP();
  }

  // Returned gradient network
  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());

  if (forwardOp.IsNetOp()) {
    // Because forwardOp is a NetOp, it is safe to static_cast it.
    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);

    // Map from an output gradient variable name to the indices, in the
    // backward net's ops_, of the operators that generate that variable.
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

    size_t local_op_id = 0;
    // Traverse forwardNet in reverse order and collect all duplicate outputs.
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
         ++it, ++local_op_id) {
      auto& fwd = *it;
      auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
      ForEachVarName(bwd->Outputs(),
                     [&dup_output_ops, local_op_id](const std::string& out) {
                       dup_output_ops[out].emplace_back(local_op_id);
                       return false;
                     });
      net->AppendOp(std::move(bwd));
    }
    // Get a unique ID for this call.
    auto uid = uniq_id++;
    // TODO(dzh): more comment
    // Multiple operators that share the same output (y, for example) would
    // overwrite the same y variable in the backward pass, so special handling
    // is taken for this case. For each duplicate output, rename it to an
    // alias (the original name with an offset), and finally insert a `sum` op
    // to add all the aliases back into the final output variable y.
    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
    for (auto& dup_output_op : dup_output_ops) {
      const std::string& name = dup_output_op.first;
      // Duplicates of @EMPTY@ do not need to be added.
      if (name == kEmptyVarName) continue;

      auto& dup_op = dup_output_op.second;
      // no duplicate output
      if (dup_op.size() == 1) continue;

      // process the duplicate outputs
      std::vector<std::string> dup_outputs;
      for (size_t i = 0; i < dup_op.size(); ++i) {
        // rename each duplicate output to an alias
        auto op_offset = dup_op[i];
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
      }
      // Collect the offset of each alias and insert a `sum` operator to add
      // all the aliases into the final output.
      insert_position.push_back(
          {dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}},
                                               {{"Out", {name}}}, {})});
    }

    // Insert the `sum` ops in descending position order, so that inserting
    // one does not invalidate the positions that are still to be processed.
    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });

    for (auto& pos : insert_position) {
      net->InsertOp(pos.first + 1, std::move(pos.second));
    }
  } else {
    std::unique_ptr<OperatorBase> grad_op(
        CreateGradOp(forwardOp, no_grad_names, grad_to_var));

    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                          const std::string& grad_input) {
      if (no_grad_names.count(grad_input)) {
        // +1 for \0
        std::string prefix = grad_input.substr(
            0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
        grad_op->Rename(grad_input, prefix + kZeroVarSuffix);

        // If some input gradient of this operator is not calculated, fill
        // that input gradient with a zero-valued variable.
        net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}},
                                           {{"Y", {grad_input}}}, {}));
      }
      return false;
    });

    ForEachVarName(grad_op->Outputs(),
                   [&no_grad_names, &grad_op](const std::string& grad_output) {
                     if (no_grad_names.count(grad_output)) {
                       grad_op->Rename(grad_output, kEmptyVarName);
                     }
                     return false;
                   });

    // Process the recurrent gradient op as a special operator.
    if (forwardOp.Type() == "recurrent") {
      // NOTE: clean up the cyclic call somewhere (the RNN's stepnet contains
      // itself), or this will result in an infinite loop.
      const auto& rnnop =
          *static_cast<const operators::RecurrentOp*>(&forwardOp);
      auto rnn_grad_op =
          static_cast<operators::RecurrentGradientOp*>(grad_op.get());
      const auto& stepnet_op =
          *static_cast<const OperatorBase*>(&rnnop.stepnet());
      // create stepnet's gradient op
      rnn_grad_op->set_stepnet(
          BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
    } else if (forwardOp.Type() == "dynamic_recurrent") {
      // NOTE: clean up the cyclic call somewhere (the RNN's stepnet contains
      // itself), or this will result in an infinite loop.
      const auto& rnnop =
          *static_cast<const operators::DynamicRecurrentOp*>(&forwardOp);
      auto rnn_grad_op =
          static_cast<operators::DynamicRecurrentGradientOp*>(grad_op.get());
      const auto& stepnet_op =
          *static_cast<const OperatorBase*>(&rnnop.rnn.GetStepUnit());
      // create stepnet's gradient op
      rnn_grad_op->rnn.SetStepUnit(
          BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
    }

    if (net->ops_.empty()) {  // Currently no auxiliary op is added to the net.
      return grad_op;
    }
    net->AppendOp(std::move(grad_op));
  }
  net->SetType("@GENERATED_BACKWARD@");
  net->CompleteAddOp();
  return std::unique_ptr<OperatorBase>(
      static_cast<OperatorBase*>(net.release()));
}

// See header for comments
std::unique_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
  no_grad_names.reserve(no_grad_vars.size() + 1);

  no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);

  for (auto& name : no_grad_vars) {
    no_grad_names.insert(name + kGradVarSuffix);
  }
  size_t uid = 0;
  std::unordered_map<std::string, std::string> grad_to_var;
  return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
}
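
// A minimal usage sketch of Backward(). The "mul" operator and its argument
// names (X, Y, Out) are illustrative assumptions, not something this file
// defines:
//
//   auto fwd = OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
//                                   {{"Out", {"out"}}}, {});
//   // Build the whole backward pass; no variables are excluded from
//   // gradient computation.
//   auto bwd = Backward(*fwd, {} /*no_grad_vars*/);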

// ====================================  //

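// Return true if, for every name in `names`, the corresponding gradient
// variable name is already in `set`.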
static bool AllGradInSet(const std::vector<std::string>& names,
                         const std::unordered_set<std::string>& set) {
  for (const std::string& name : names) {
    if (!set.count(GradVarName(name))) {
      return false;
    }
  }
  return true;
}

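// Strip the "@GRAD" suffix from a gradient variable name; return an empty
// string if `grad_name` is not a gradient variable name.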
static std::string FwdName(const std::string& grad_name) {
  auto pos = grad_name.find("@GRAD");
  if (pos == std::string::npos) {
    return "";
  } else {
    return grad_name.substr(0, pos);
  }
}

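// For every output of the backward ops (starting at grad_op_start_index) that
// does not yet exist in `block_desc`, create the variable, infer its type and
// shape, and record in `grad_var_record` where the gradient of each parameter
// is produced (name, block index, op index).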
static void CreateGradVarInBlock(
    size_t grad_op_start_index,
    const std::unordered_map<std::string, std::string>& param_name_map,
    BlockDescBind* block_desc,
    std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
  auto ops = block_desc->AllOps();
  for (size_t op_index = grad_op_start_index; op_index < ops.size();
       ++op_index) {
    bool need_infer_shape = false;
    std::unordered_set<std::string> new_vars;
    ForEachVarName(ops[op_index]->Outputs(),
                   [&](const std::string& grad_var_name) {
                     if (block_desc->HasVar(grad_var_name)) {
                       return false;
                     }
                     need_infer_shape = true;
                     auto var = block_desc->Var(grad_var_name);
                     new_vars.insert(var->Name());
                     auto it = param_name_map.find(grad_var_name);
                     if (it == param_name_map.end()) {
                       return false;
                     }
                     auto param_var_name = it->second;
                     auto& grad_record = (*grad_var_record)[param_var_name];
                     grad_record.name_ = grad_var_name;
                     grad_record.block_idx_ = block_desc->ID();
                     grad_record.op_idx_ = static_cast<int>(op_index);
                     return false; /* not break */
                   });
    if (need_infer_shape) {
      ops[op_index]->InferVarType(block_desc);
      for (auto& arg : ops[op_index]->OutputArgumentNames()) {
        if (new_vars.find(arg) == new_vars.end()) {
          continue;
        }
        auto pname = FwdName(arg);
        auto* param = block_desc->FindVar(pname);
        auto* grad = block_desc->FindVar(arg);
        if (param == nullptr) {
          LOG(WARNING) << "Cannot find forward variable of " << arg
                       << ". Set its gradient to FP32";
          grad->SetDataType(DataType::FP32);
        } else {
          grad->SetDataType(param->GetDataType());
        }
      }
      ops[op_index]->InferShape(*block_desc);
    }
  }
}

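// Create the gradient op descriptions for a single forward op description.
// Inputs whose gradients appear in `no_grad_vars` are replaced by zero-filled
// variables; an empty vector is returned when no gradient op is needed.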
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
    const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
  // None of the input gradients of the forward operator needs to be
  // calculated.
  const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
  if (AllGradInSet(inputs, *no_grad_vars)) {
    return grad_op_descs;  // empty vector
  }
  // None of the output gradients of the forward operator needs to be
  // calculated.
  const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
  if (AllGradInSet(outputs, *no_grad_vars)) {
    for (const std::string& name : inputs) {
      no_grad_vars->insert(GradVarName(name));
    }
    return grad_op_descs;  // empty vector
  }

  grad_op_descs = OpInfoMap::Instance()
                      .Get(op_desc->Type())
                      .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);

  std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
  for (auto& desc : grad_op_descs) {
    for (const std::string& in_name : desc->InputArgumentNames()) {
      if (no_grad_vars->count(in_name)) {
        std::string prefix = in_name.substr(
            0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
        std::string new_name = prefix + kZeroVarSuffix;
        desc->Rename(in_name, new_name);
        std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
        pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
      }
    }
  }

  for (auto& p : pending_fill_zeros_ops) {
    grad_op_descs.insert(grad_op_descs.begin(), std::move(p));
  }
  return grad_op_descs;
}

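// Create the backward op descriptions for every op in the given block,
// recursing into the step blocks of recurrent ops. Outputs written by more
// than one gradient op are renamed and merged with a `sum` op.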
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
    ProgramDescBind& program_desc, int block_idx,
    std::unordered_set<std::string>* no_grad_vars,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
  std::vector<OpDescBind*> op_descs = cur_block->AllOps();
  std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
  size_t grad_desc_idx = 0;
  std::vector<std::unique_ptr<OpDescBind>> backward_descs;

  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
    std::vector<std::unique_ptr<OpDescBind>> op_grads =
        MakeOpGrad(*it, no_grad_vars, grad_to_var);

    if ((*it)->Type() == "recurrent") {
      PADDLE_ENFORCE_EQ(
          op_grads.size(), static_cast<size_t>(1),
          "rnn_op's gradient process should contain only one op.");
      int step_block_idx = (*it)->GetBlockAttr("step_block");
      auto backward_block_op_descs = MakeBlockBackward(
          program_desc, step_block_idx, no_grad_vars, grad_to_var);
      BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
      for (auto& ptr : backward_block_op_descs) {
        backward_block->AppendAllocatedOp(std::move(ptr));
      }
      op_grads[0]->SetBlockAttr("step_block", *backward_block);
    }

    for (const auto& desc : op_grads) {
      for (const std::string& out_name : desc->OutputArgumentNames()) {
        dup_out_ops[out_name].emplace_back(grad_desc_idx);
      }
      ++grad_desc_idx;
    }
    std::transform(
        op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
        [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
  }
  // Check whether some variables are written more than once; if so, rename
  // the duplicated writes and merge them with a `sum` op.
  std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
  for (const auto& dup : dup_out_ops) {
    const std::string& out_name = dup.first;
    const std::vector<size_t> dup_op = dup.second;
    if (out_name != kEmptyVarName && dup_op.size() > 1) {
      std::vector<std::string> sum_op_inputs;
      for (size_t i = 0; i < dup_op.size(); ++i) {
        std::string new_name = out_name + "@RENAME@" + std::to_string(i);
        backward_descs[dup_op[i]]->Rename(out_name, new_name);
        sum_op_inputs.emplace_back(new_name);
      }
      std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
          "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
      pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
    }
  }
  pending_sum_ops.sort(
      [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
         const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
        return a.first > b.first;
      });
  for (auto& p : pending_sum_ops) {
    backward_descs.insert(backward_descs.begin() + p.first + 1,
                          std::move(p.second));
  }

  return backward_descs;
}

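// Append the backward part to `program_desc` for the given target variable.
// Returns a map from each parameter name to the GradVarInfo (gradient name,
// block index, op index) of its gradient variable.
//
// A minimal usage sketch (the program and loss variable are assumed to exist
// already; the names are illustrative):
//
//   ProgramDescBind& program = ...;   // program holding the forward ops
//   const VarDescBind& loss = ...;    // the scalar loss variable
//   auto param_to_grad = AppendBackward(program, loss, {} /*no_grad_vars*/);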
ParamGradInfoMap AppendBackward(
    ProgramDescBind& program_desc, const VarDescBind& target,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_var_names;
  no_grad_var_names.reserve(no_grad_vars.size() + 1);
  no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
  for (auto& name : no_grad_vars) {
    no_grad_var_names.insert(GradVarName(name));
  }

  const int root_block_idx = 0;
  auto root_block = program_desc.MutableBlock(root_block_idx);

  // Insert a fill-one op for the target.
  // TODO(qiao) add some check to the target.
  std::string fill_one_op_out = GradVarName(target.Name());
  std::vector<int64_t> target_shape_desc = target.Shape();
  std::vector<int> target_shape;
  std::transform(target_shape_desc.begin(), target_shape_desc.end(),
                 std::back_inserter(target_shape),
                 [](int64_t dim) { return static_cast<int>(dim); });
  VLOG(3) << "backward from loss=" << target.Name()
          << " data_type=" << target.GetDataType();
  std::unique_ptr<OpDescBind> fill_one_op(
      new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
                     {{"shape", target_shape},
                      {"value", static_cast<float>(1.0)},
                      {"data_type", target.GetDataType()}}));
  // Infer the variable type of fill_one_op's output.
  fill_one_op->InferVarType(root_block);

  root_block->AppendAllocatedOp(std::move(fill_one_op));
  size_t forward_op_num = root_block->OpSize();
  size_t forward_block_num = program_desc.Size();

  // Insert backward operators.
  std::unordered_map<std::string, std::string> grad_to_var;
  auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
                                             &no_grad_var_names, &grad_to_var);

  for (auto& ptr : backward_op_descs) {
    root_block->AppendAllocatedOp(std::move(ptr));
  }
  // Create variables.

  // Create the target gradient variable.
  std::unordered_map<std::string, GradVarInfo> retv;

  auto var = root_block->Var(fill_one_op_out);
  var->SetDataType(target.GetDataType());
  var->SetShape(target.Shape());
  auto& target_grad = retv[target.Name()];
  target_grad.name_ = fill_one_op_out;
  target_grad.block_idx_ = root_block_idx;
  target_grad.op_idx_ = static_cast<int>(forward_op_num);

  // Create gradient variables for all blocks in this program.
  CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
  for (size_t block_index = forward_block_num;
       block_index < program_desc.Size(); ++block_index) {
    CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
                         &retv);
  }
  return retv;
}

}  // namespace framework
}  // namespace paddle