// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

namespace paddle {
namespace operators {

using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor;

static constexpr char kStepBlock[] = "sub_block";
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";

namespace {  // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
    const std::vector<std::string> &vars) {
  std::string str = "Skip " + std::to_string(vars.size()) +
                    " var(s) in eager deletion mode: ";
  for (auto &var : vars) {
    str.append(var);
    str.push_back(' ');
  }
  return str;
}
}  // NOLINT

class WhileOp : public framework::OperatorBase {
 public:
  WhileOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
    auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
    PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));

    framework::Executor executor(dev_place);
    auto *block = Attr<framework::BlockDesc *>(kStepBlock);

    auto *program = block->Program();

    auto step_scopes =
        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();

    PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
                   "Condition of while op must be in CPU memory.");

    bool is_test = Attr<bool>("is_test");
    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    if (framework::GetEagerDeletionThreshold() >= 0) {
      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
    }

    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
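    // Each iteration runs the step block in a freshly created scope and then
    // re-reads the condition tensor, which the step block is expected to
    // update. In is_test mode the per-step scope is dropped immediately,
    // since no backward pass will need it.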
    while (cond.data<bool>()[0]) {
      auto &current_scope = scope.NewScope();
      step_scopes->push_back(&current_scope);
      executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
      if (is_test) {
        scope.DeleteScope(&current_scope);
      }
    }
  }
};

class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(kX,
             "A set of variables, which are required by operators inside the "
             "block of While Op.")
        .AsDuplicable();
    AddInput(
        kCondition,
        "(Bool) A scalar. When it's False, the While Op will be terminated.")
        .AsDuplicable();
    AddOutput(kOutputs,
              "A set of variables, which will be assigned with values "
              "generated by the operators inside the block of While Op.")
        .AsDuplicable();
    AddOutput(kStepScopes,
              "(StepScopeVar) A vector of local scopes whose size equals the "
              "step number of the While Op. The i'th scope stores temporary "
              "variables generated in the i'th step.");
    AddAttr<framework::BlockDesc *>(kStepBlock,
                                    "The step block inside WhileOp");
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<std::vector<std::string>>(kSkipEagerDeletionVars,
                                      "Vars that would skip eager deletion. "
                                      "Users should not set this manually.")
        .SetDefault(std::vector<std::string>());
    AddComment(R"DOC(
)DOC");
  }
};

class WhileGradOp : public framework::OperatorBase {
 public:
  WhileGradOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_ENFORCE(!Attr<bool>("is_test"),
                   "GradOp is only callable when is_test is false");
    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
    framework::Executor executor(dev_place);
    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
    auto *program = block->Program();

    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    if (framework::GetEagerDeletionThreshold() >= 0) {
      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
    }
    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);

    auto *step_scopes =
        scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();

    auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
    auto inside_og_names =
        Attr<std::vector<std::string>>("original_output_grad");

    PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size());

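    // Walk the forward step scopes in reverse order. For each step: link the
    // outer output gradients into the step scope, run the gradient block, and
    // accumulate the parameter gradients into the outer scope.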
    for (auto cur_scope_iter = step_scopes->rbegin();
         cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) {
      VLOG(3) << "Start backward at time_step "
              << cur_scope_iter - step_scopes->rbegin();
      framework::Scope &cur_scope = **cur_scope_iter;
      // Link OG from outside to inside
      for (size_t i = 0; i < outside_og_names.size(); ++i) {
        auto outside_og_name = outside_og_names[i];
        auto inside_og_name = inside_og_names[i];
        VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
                << inside_og_name;
        if (scope.FindVar(outside_og_name) == nullptr) {
          continue;
        }

        auto &og_outside =
            detail::Ref(scope.FindVar(outside_og_name),
                        "Cannot find Outside Gradient %s", outside_og_name);
        auto &og_inside =
            detail::Ref(cur_scope.Var(inside_og_name),
                        "Cannot find inside gradient %s", inside_og_name);
        if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
          auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
          auto &inside_tensor =
              detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
          inside_tensor.set_lod(outside_tensor.lod());
          inside_tensor.ShareDataWith(outside_tensor);
        } else if (framework::IsType<framework::LoDTensorArray>(
                       og_outside.Type())) {
          auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
          auto &inside_array =
              detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
          VLOG(8) << outside_og_name << " size = " << outside_array.size();
          inside_array.resize(outside_array.size());

          for (size_t j = 0; j < inside_array.size(); ++j) {
            VLOG(8) << j << " " << outside_array[j].numel();
            if (outside_array[j].numel() != 0) {
              inside_array[j].set_lod(outside_array[j].lod());
              inside_array[j].ShareDataWith(outside_array[j]);
            } else {
              PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
            }
          }
        } else {
          PADDLE_THROW("Currently only LoDTensor and LoDTensorArray are supported.");
        }
      }
      executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
                                  true);

      // Outputs(kXGRAD) contains the names of the gradients of the parameters
      // and the inputs.
      auto &pg_ig_names = Outputs(kXGRAD);
      auto &p_names = Inputs(kX);
      PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
      for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
        if (pg_ig_names[param_id] == framework::kEmptyVarName) {
          continue;  // parameter doesn't have gradient
        }
        auto inside_grad_name = framework::GradVarName(p_names[param_id]);

        // Some grad ops have inputs that do not have gradients; for example,
        // in lookup_table_grad_op the input(Idx) has no gradient.
        auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
        PADDLE_ENFORCE(pg_ig_var != nullptr);
        if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
          auto pg_ig_lod_t_arr =
              pg_ig_var->GetMutable<framework::LoDTensorArray>();
          bool empty = true;
          for (auto &each : *pg_ig_lod_t_arr) {
            if (each.numel() != 0) {
              empty = false;
              break;
            }
          }
          if (empty) {
            LOG(WARNING) << pg_ig_names[param_id]
                         << " is not found in cur_scope.";
            continue;
          }
        }

        //  // TODO(tonyyang-svail): Not sure we need the following
        //  // If the gradient of that variable is not computed inside the
        //  // rnn, just continue
        //  if (local_var_names.find(inside_grad_name) ==
        //      local_var_names.end()) {
        //    continue;
        //  }

        // zero gradient variable in step 0
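        // (the outer gradient variable has not been written yet when the
        // first reverse step runs, so it is filled with zeros here and then
        // accumulated into by every step)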
        if (cur_scope_iter == step_scopes->rbegin()) {
          auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
          PADDLE_ENFORCE_NOT_NULL(var, "Cannot find var %s", inside_grad_name);
          PADDLE_ENFORCE(
              var->IsType<framework::LoDTensorArray>() ||
                  var->IsType<LoDTensor>(),
              "Currently the type of var can only be LoDTensorArray or "
              "LoDTensor, but the received var[%s] is %s.",
              inside_grad_name, var->Type().name());

          if (var->IsType<LoDTensor>()) {
            auto &inside_tensor = var->Get<framework::LoDTensor>();
            framework::AttributeMap attrs;
            attrs["dtype"] = framework::ToDataType(inside_tensor.type());
            attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
            attrs["value"] = 0.0f;

            auto var_name = pg_ig_names[param_id];
            auto zero_op = framework::OpRegistry::CreateOp(
                "fill_constant", framework::VariableNameMap{},
                {{"Out", {var_name}}}, attrs);
            zero_op->Run(scope, dev_place);
            scope.FindVar(var_name)
                ->GetMutable<framework::LoDTensor>()
                ->set_lod(inside_tensor.lod());
          }
        }
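        // Accumulate this step's gradient into the outer gradient variable:
        // temporarily rename the inside gradient, run a sum op that adds it
        // to the outer gradient, then restore the original name so the next
        // iteration can reuse it.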
        auto new_inside_name = cur_scope.Rename(inside_grad_name);
        auto sum_op = framework::OpRegistry::CreateOp(
            "sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
            {{"Out", {pg_ig_names[param_id]}}},
            framework::AttributeMap{{"use_mkldnn", {false}}});
        sum_op->Run(cur_scope, dev_place);
        cur_scope.Rename(new_inside_name, inside_grad_name);
      }
      dev_ctx.Wait();
      const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
    }
  }
};

class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto *while_grad = new framework::OpDesc();
    while_grad->SetType("while_grad");
    while_grad->SetInput(kX, Input(kX));
    while_grad->SetInput(kOutputs, Output(kOutputs));
    while_grad->SetInput(kStepScopes, Output(kStepScopes));

    auto *grad_block = this->grad_block_[0];
    auto *fwd_block = grad_block->ForwardBlock();
    auto *parent_block = grad_block->ParentBlock();

    // Not all of the IGs will be generated by the inner gradient operators of
    // the while op. Ignore IGs that are not generated by the inside block.
    std::unordered_set<std::string> inner_op_outputs;
    for (const auto *op : grad_block->AllOps()) {
      for (auto &oname : op->OutputArgumentNames()) {
        inner_op_outputs.insert(oname);
      }
    }
    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
    for (auto &each_ig : igs) {
      if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
        VLOG(8) << "Ignore " << each_ig;
        each_ig = framework::kEmptyVarName;
      }
    }
    while_grad->SetOutput(framework::GradVarName(kX), igs);

    // OG should be re-calculated by the step blocks, since many outputs of the
    // while op do not need their gradients to be calculated.
    std::unordered_set<std::string> block_ins;
    block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
    for (auto &p : Input(kX)) {
      block_ins.insert(p);
    }
    for (auto &o : Output(kOutputs)) {
      block_ins.insert(o);
    }
    std::unordered_set<std::string> output_grads;
    for (const auto *op : grad_block->AllOps()) {
      for (auto &input_name : op->InputArgumentNames()) {
        // If the input of an op has already been recorded or is generated by
        // the forward block, do not add it as an input again.

        // The input is located in I/O or other ops' outputs, or the variable
        // is located in one of grad_block's parents.
        if (block_ins.find(input_name) != block_ins.end() ||
            (fwd_block->FindVarRecursive(input_name) != nullptr ||
             parent_block->FindVarRecursive(input_name) != nullptr)) {
          continue;
        }

        output_grads.insert(input_name);
      }
      for (auto &output_name : op->OutputArgumentNames()) {
        block_ins.insert(output_name);
      }
    }

    std::vector<std::string> output_grads_list;
    output_grads_list.resize(output_grads.size());
    std::copy(output_grads.begin(), output_grads.end(),
              output_grads_list.begin());
    while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);

    while_grad->SetAttrMap(this->Attrs());
    while_grad->SetBlockAttr(kStepBlock, grad_block);
    // Record the original output gradient names, since the gradient names of
    // the while operator could be renamed.
    while_grad->SetAttr("original_output_grad", output_grads_list);

    /* The following code is used in eager deletion mode */
    std::unordered_set<std::string> bwd_skip_vars;
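    // When eager deletion is enabled, variables from the forward block (or a
    // parent block) that the gradient block still reads must not be freed
    // early; they are collected into fwd_skip_vars for the forward while op,
    // and the gradient outputs are collected into bwd_skip_vars for
    // while_grad.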
    if (framework::GetEagerDeletionThreshold() >= 0) {
      std::unordered_set<std::string> fwd_skip_vars;
      for (auto *op_desc : grad_block->AllOps()) {
        auto skippable = [&](const std::string &name) {
          return !grad_block->HasVar(name) &&
                 (fwd_block->HasVarRecursive(name) ||
                  parent_block->HasVarRecursive(name));
        };
        for (auto &in_arg_name : op_desc->InputArgumentNames()) {
          if (skippable(in_arg_name)) {
            fwd_skip_vars.insert(in_arg_name);
          }
        }

        for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
          if (skippable(out_arg_name)) {
            fwd_skip_vars.insert(out_arg_name);
          }
        }
      }

      if (!fwd_skip_vars.empty()) {
        // FIXME(zjl): ugly const_cast here, maybe we should find a better way
        // to modify forward while_op
        auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp());
        fwd_while_op.SetAttr(kSkipEagerDeletionVars,
                             std::vector<std::string>(fwd_skip_vars.begin(),
                                                      fwd_skip_vars.end()));
      }

      // Find backward skip vars
      auto fwd_input = Input(kX);
      for (size_t i = 0; i < igs.size(); ++i) {
        if (igs[i] == framework::kEmptyVarName) {
          continue;
        }
        bwd_skip_vars.insert(igs[i]);
        bwd_skip_vars.insert(framework::GradVarName(fwd_input[i]));
      }
    }
    while_grad->SetAttr(
        kSkipEagerDeletionVars,
        std::vector<std::string>(bwd_skip_vars.begin(), bwd_skip_vars.end()));

    return std::unique_ptr<framework::OpDesc>(while_grad);
  }
};

class WhileGradOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    auto p_names = op_desc.Input(kX);
    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));

    for (size_t i = 0; i < p_names.size(); ++i) {
      auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
      if (g_var != nullptr) {  // Gradient could be @EMPTY@
        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
                << " type: " << p_var.GetType();
        g_var->SetType(p_var.GetType());
        g_var->SetDataType(p_var.GetDataType());
      }
    }
  }
};

class WhileGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
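    // Forward the dims of each input in kX to its gradient output; only
    // LOD_TENSOR and LOD_TENSOR_ARRAY inputs are handled here.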
    ctx->HasInputs(kX);
    ctx->HasOutputs(framework::GradVarName(kX));
    ctx->HasInputs(kOutputs);
    ctx->HasInputs(framework::GradVarName(kOutputs));

    auto p_names = ctx->Inputs(kX);
    auto pg_ig_names = ctx->Outputs(kXGRAD);
    auto var_types = ctx->GetInputsVarType(kX);
    std::vector<std::string> names_to_set;
    std::vector<framework::DDim> dims_to_set;
    for (size_t i = 0; i < p_names.size(); ++i) {
      if (pg_ig_names[i] == framework::kEmptyVarName) {
        continue;
      }
      auto dims = ctx->GetInputsElementDim(kX, i);
      if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
        names_to_set.push_back(pg_ig_names[i]);
        dims_to_set.push_back(dims);
      } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
        // not sure how to set the dim of LOD_TENSOR_ARRAY
        names_to_set.push_back(pg_ig_names[i]);
        dims_to_set.push_back(dims);
      }
    }
    ctx->SetDims(names_to_set, dims_to_set);
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(while, paddle::operators::WhileOp,
                  paddle::operators::WhileOpMaker,
                  paddle::operators::WhileGradOpDescMaker);
REGISTER_OPERATOR(while_grad, paddle::operators::WhileGradOp,
                  paddle::operators::WhileGradOpShapeInference,
                  paddle::operators::WhileGradOpVarTypeInference);