layer.cc 17.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"
M
minqiyang 已提交
16

17 18 19 20
#include <deque>
#include <limits>
#include <map>
#include <random>
M
minqiyang 已提交
21
#include <unordered_set>
22 23 24 25
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
26
#include "paddle/fluid/framework/operator.h"
M
minqiyang 已提交
27 28 29
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
30 31 32 33 34
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

X
polish  
Xin Pan 已提交
35 36
const char* PyLayer::kFwdInp = "X";
const char* PyLayer::kFwdOut = "Out";
X
polish  
Xin Pan 已提交
37

X
Xin Pan 已提交
38 39
std::map<int, py::object> py_funcs_;

40 41
using framework::Variable;

M
minqiyang 已提交
42 43 44 45 46 47 48 49 50 51 52
namespace detail {

// Visitor over platform::Place that merges one buffer into another by calling
// the BLAS AXPY routine of the visited place's device context with a scale of
// 1 (in standard BLAS terms: y = 1 * x + y over numel elements).
//
// CPUPlace is always supported; CUDAPlace only when built with
// PADDLE_WITH_CUDA; CUDAPinnedPlace is never supported and throws.
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  // numel: element count handed to AXPY.
  // x:     source buffer (read only).
  // y:     destination buffer, updated in place.
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    auto* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    // A failed dynamic_cast yields nullptr; fail loudly instead of
    // dereferencing it below.
    if (ctx == nullptr) {
      PADDLE_THROW("Cannot get CPUDeviceContext for place %s", place);
    }
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    auto* ctx = dynamic_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    // Same guard as the CPU overload.
    if (ctx == nullptr) {
      PADDLE_THROW("Cannot get CUDADeviceContext for place %s", place);
    }
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }
#endif

  // there is NO blas in CUDAPinnedPlace
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Do NOT support gradient merge in place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

P
Paddle CI 已提交
84
// Adds the LoDTensor held by `src` onto the LoDTensor held by `dst`, running
// on `place`. Both tensors are accessed as float data.
//
// An empty source (numel == 0) is silently ignored; otherwise both tensors
// must have the same number of elements or PADDLE_ENFORCE fails.
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* accum = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* addend = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
  // ugly fix for it
  if (addend->numel() == 0) return;

  PADDLE_ENFORCE(accum->numel() == addend->numel(),
                 "dst_numel %lld vs. src_numel %lld", accum->numel(),
                 addend->numel());

  const float* addend_data = addend->data<float>();
  float* accum_data = accum->mutable_data<float>(place);
  detail::TensorAddToFunctor<float> adder(addend->numel(), addend_data,
                                          accum_data);
  boost::apply_visitor(adder, place);
}

104 105 106 107 108 109 110 111 112 113 114
// Sums every partial gradient queued for `target` in `*bck_map` into
// target's own variable. The partials are processed in descending trace-id
// order; each one is deleted after it has been added and its slot in the
// queue is reset to nullptr.
//
// Fails via PADDLE_ENFORCE when `target` has no entry in `*bck_map`.
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  using TracedGrad = std::pair<int, VarBase*>;

  auto entry = bck_map->find(target);
  PADDLE_ENFORCE(entry != bck_map->end(), "Can't find %s in backward grad map",
                 target->Name());
  std::pair<platform::Place, std::vector<TracedGrad>>& current = entry->second;
  std::vector<TracedGrad>& pending = current.second;

  // Largest trace id first.
  std::sort(pending.begin(), pending.end(),
            [](const TracedGrad& lhs, const TracedGrad& rhs) {
              return rhs.first < lhs.first;
            });

  for (size_t idx = 0; idx < pending.size(); ++idx) {
    TracedGrad& partial = pending[idx];
    Variable* origin_grad = target->var_.get();
    Variable* grad_to_add = partial.second->var_.get();
    VLOG(2) << "add origin_grad: " << target->Name();
    VLOG(2) << "added grad: " << partial.second->Name()
            << " trace id is: " << partial.first;
    AddTo(grad_to_add, origin_grad, current.first);
    // The partial gradient has been merged; release it.
    delete partial.second;
    partial.second = nullptr;
  }
}

126 127
class Autograd {
 public:
X
Xin Pan 已提交
128
  Autograd() {}
129

130
  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
X
Xin Pan 已提交
131
    if (var->IsStopGradient()) {
132 133
      return;
    }
X
Xin Pan 已提交
134
    VLOG(3) << "start autograd";
135 136
    BackwardSumMap bck_map;
    GradientRef grad_ref;
137
    std::deque<OpBase*> ready;
X
Xin Pan 已提交
138
    ready.push_back(var->PreOp());
139

140
    std::map<OpBase*, int> dep_counts =
141
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);
142 143 144 145

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
X
Xin Pan 已提交
146
      std::map<std::string, std::vector<VarBase*>> input_grads =
147
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
X
Xin Pan 已提交
148

149 150
      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
X
Xin Pan 已提交
151 152
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
Y
Yan Xu 已提交
153 154 155
          auto p = ready_op->input_vars_[it->first][i];

          if (p->IsStopGradient()) continue;
156
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
X
Xin Pan 已提交
157 158 159 160 161 162 163 164
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
165 166
        }
      }
167 168

      ready_op->InvokeBackwardHooks();
169 170 171 172
    }
  }

 private:
173
  std::map<OpBase*, int> ComputeDepCounts(
174 175 176 177 178 179 180
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
181 182 183 184 185 186 187 188 189
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
190 191 192 193
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
194
              ++(*grad_ref)[vb];
195 196 197 198
            }
          }
        }
      }
X
Xin Pan 已提交
199
      for (auto it : candidate->pre_ops_) {
X
Xin Pan 已提交
200 201
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
202
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
203
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
204
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
X
Xin Pan 已提交
205 206 207 208 209
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
210 211 212 213 214 215 216
        }
      }
    }
    return ret;
  }
};

M
minqiyang 已提交
217 218
// Creates a new VarBase named "Itmp" whose LoDTensor is a copy of this
// variable's tensor (same type, dims and LoD) placed on `dst_place`.
//
// When `blocking` is true, waits on the device context of `dst_place` and,
// if it differs, on the device context of the source tensor's place before
// returning. Requires the underlying variable to be initialized.
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& origin = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> copied(new VarBase(
      "Itmp", origin.type(), origin.dims(), dst_place, true, false));

  framework::LoDTensor* copied_tensor =
      copied->var_->GetMutable<framework::LoDTensor>();
  copied_tensor->set_lod(origin.lod());
  framework::TensorCopy(origin, dst_place, copied_tensor);

  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = origin.place();
    const bool same_place = (src_place == dst_place);
    if (!same_place) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  // NOTE(review): this fires when the *destination* is a GPU place although
  // the message says "from gpu" -- confirm the intended wording.
  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return copied;
}

M
minqiyang 已提交
247
// Returns a mutable reference to the LoDTensor stored in this variable's
// gradient variable. Fails via PADDLE_ENFORCE_NOT_NULL when grads_ is null.
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  framework::LoDTensor* grad_tensor =
      grads_->var_->GetMutable<framework::LoDTensor>();
  return *grad_tensor;
}

254 255 256
// Runs the backward computation of this op and merges the results into the
// gradient variables recorded in grad_output_vars_.
//
// Gradients are first computed into freshly allocated temporary VarBase
// objects: through the registered Python function when backward_id_ > 0,
// otherwise by creating and running each op described in grad_op_descs_.
// Every temporary is then merged into the matching original gradient:
//   * default strategy: added immediately with AddTo and deleted;
//   * bck_stratedy.sorted_sum_gradient_: queued in *bck_map together with
//     this op's trace id, and summed by AddGradBySort once the counter in
//     *grad_ref shows this was the last pending contribution.
//
// Fails via PADDLE_ENFORCE when the op has no backward implementation, when
// a grad op has no kernel, or when the bookkeeping maps are inconsistent.
//
// Returns input_vars_.
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
                 "%s has no backward implementation", Type());

  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
    tmp_grad_outputs.resize(1);
    tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
        PyLayer::ApplyGrad(
            backward_id_,
            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
  } else {
    const size_t grad_op_count = grad_op_descs_.size();

    tmp_grad_outputs.resize(grad_op_count);
    for (size_t k = 0; k < grad_op_count; ++k) {
      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
      auto& grad_output_variable_map = grad_output_vars_[k];

      VLOG(3) << "apply grad op " << grad_op_desc->Type();

      // Allocate tmp grad output variable
      for (const auto& it : grad_output_variable_map) {
        auto& outputs = tmp_grad_outputs[k][it.first];
        outputs.reserve(it.second.size());
        for (VarBase* origin_grad_var_base : it.second) {
          // Allocate a new variable
          VarBase* tmp_grad_var_base = new VarBase(
              string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
              origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
              place_, true, false);
          outputs.emplace_back(tmp_grad_var_base);
        }
      }

      // No need to do compile time infer shape here.
      // grad_op_desc_->InferShape(*block_);
      // grad_op_desc->InferVarType(block_);

      std::unique_ptr<framework::OperatorBase> opbase =
          framework::OpRegistry::CreateOp(*grad_op_desc);

      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
      if (info.infer_var_type_) {
        RuntimeInferVarTypeContext infer_var_type_ctx(
            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
        info.infer_var_type_(&infer_var_type_ctx);
      }

      framework::OperatorWithKernel* op_kernel =
          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
      PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

      // Run grad op
      framework::VariableValueMap grad_invars_map;
      framework::VariableValueMap grad_outvars_map;

      for (const auto& it : grad_input_vars_[k]) {
        auto& grad_invars = grad_invars_map[it.first];
        grad_invars.reserve(it.second.size());
        for (const VarBase* grad_inp : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                  grad_op_desc->Type(), grad_inp->Name());

          grad_invars.emplace_back(grad_inp->var_.get());
        }
      }

      for (const auto& it : tmp_grad_outputs[k]) {
        auto& grad_outvars = grad_outvars_map[it.first];
        grad_outvars.reserve(it.second.size());
        for (VarBase* grad_out : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                  grad_op_desc->Type(), grad_out->Name());

          grad_outvars.emplace_back(grad_out->var_.get());
        }
      }

      framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
      framework::Scope scope;
      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
      p.op.RuntimeInferShape(scope, place_, ctx);
      p.func(
          framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
    }
  }

  // The merge loop below indexes tmp_grad_outputs with every index of
  // grad_output_vars_; make sure that cannot run past the end.
  PADDLE_ENFORCE(grad_output_vars_.size() <= tmp_grad_outputs.size(),
                 "%s has %d grad output maps but only %d were computed", Type(),
                 grad_output_vars_.size(), tmp_grad_outputs.size());

  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          // NOTE(review): reads element 0 as float whenever VLOG(2) is
          // active; presumably assumes a non-empty float tensor -- confirm.
          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
          VLOG(2) << origin_outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(2) << outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
#endif
          // Queue this partial gradient under the original gradient.
          auto sum_entry = bck_map->find(origin_outputs[i]);
          if (sum_entry != bck_map->end()) {
            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
            sum_entry->second.second.emplace_back(
                std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], std::move(tmp)));
          }

          auto ref_entry = grad_ref->find(origin_outputs[i]);
          PADDLE_ENFORCE(ref_entry != grad_ref->end(),
                         "Can't find  %s in grad_reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(ref_entry->second >= 1,
                         "Backward error when calculate grad reference");
          if (ref_entry->second > 1) {
            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
          } else {
            // Last pending contribution: sum everything queued so far.
            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
          }
          ref_entry->second--;
        } else {
          framework::Variable* grad = outputs[i]->var_.get();
          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
          VLOG(2) << "AddTo Called with orig_grad is: "
                  << origin_outputs[i]->name_ << " Grad to be added is "
                  << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          // The temporary has been merged; release it.
          delete outputs[i];
        }
      }
    }
  }

  return input_vars_;
}

409
void OpBase::InvokeBackwardHooks() {
M
minqiyang 已提交
410
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();
411 412 413 414 415 416 417

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

Y
Yan Xu 已提交
418
// Appends `callable` to the list of hooks run by InvokeBackwardHooks.
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.emplace_back(callable);
}

425
// Starts the backward pass from this variable.
//
// Does nothing when the variable has no producing op. Otherwise fills this
// variable's gradient tensor with 1.0, checks that grads_ is the gradient
// registered for the matching output of the producing op, and hands over to
// Autograd::RunBackward.
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;

  VLOG(3) << "start backward";
  // grads_ is dereferenced right below; report a missing gradient variable
  // as an error instead of dereferencing null.
  PADDLE_ENFORCE(grads_ != nullptr,
                 "Could not run backward on %s: it has no grad variable",
                 Name());
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  // The device context is chosen from the place of the forward tensor.
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

X
Xin Pan 已提交
441 442 443 444
// Stores `py_func` under `func_id`, replacing any callable already
// registered with that id.
void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
  py::object& slot = py_funcs_[func_id];
  slot = py_func;
}

X
polish  
Xin Pan 已提交
445 446
// Returns how many Python callables are currently registered.
int PyLayer::NumFuncs() {
  const auto registered = py_funcs_.size();
  return static_cast<int>(registered);
}

447
// Calls the Python callable registered under `func_id` with `inputs` and
// returns the variables produced by CallPythonFunc. Fails via PADDLE_ENFORCE
// when no callable is registered for `func_id`.
std::vector<std::unique_ptr<framework::Variable>> PyLayer::Apply(
    int func_id, const std::vector<VarBase*>& inputs) {
  auto registered = py_funcs_.find(func_id);
  PADDLE_ENFORCE(registered != py_funcs_.end());
  return CallPythonFunc(registered->second, inputs);
}

M
minqiyang 已提交
453 454
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
                                         const std::vector<VarBase*>& inputs) {
X
polish  
Xin Pan 已提交
455
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
M
minqiyang 已提交
456 457 458 459 460 461 462 463
  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);

  std::vector<VarBase*> outs;
  outs.reserve(rets.size());
  for (size_t i = 0U; i != rets.size(); ++i) {
    outs.emplace_back(new VarBase(
        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
                        i),
464
        std::move(rets[i]), nullptr, true));
M
minqiyang 已提交
465 466 467
  }

  return outs;
X
polish  
Xin Pan 已提交
468
}
X
Xin Pan 已提交
469

470
// Invokes `callable` with a single tuple argument built from `ins` and
// converts the tuple it returns into framework::Variable objects.
//
// Each input's LoDTensor is passed as a Python object, or as the cast of
// nullptr when the tensor is not initialized. Each returned element must
// cast to a non-null LoDTensor; the resulting variable shares that tensor's
// data (ShareDataWith) and copies its LoD. Throws via PADDLE_THROW when an
// element cannot be cast to LoDTensor. The GIL is held for the whole call.
std::vector<std::unique_ptr<framework::Variable>> PyLayer::CallPythonFunc(
    const py::object& callable, const std::vector<VarBase*>& ins) {
  py::gil_scoped_acquire guard;

  py::tuple in_args(ins.size());
  for (size_t idx = 0; idx < ins.size(); ++idx) {
    const framework::LoDTensor& in_tensor =
        ins[idx]->var_->Get<framework::LoDTensor>();
    if (in_tensor.IsInitialized()) {
      in_args[idx] = py::cast(in_tensor);
    } else {
      in_args[idx] = py::cast(nullptr);
    }
  }
  VLOG(3) << "pyfunc in " << py::len(in_args);

  // TODO(panyx0718): Who owns the returned LoDTensor.
  auto py_result = callable(in_args);
  auto result_tuple = py::cast<py::tuple>(py_result);
  size_t ret_num = py::len(result_tuple);

  std::vector<std::unique_ptr<framework::Variable>> outs;
  outs.reserve(ret_num);
  VLOG(3) << "pyfunc out " << ret_num;
  for (size_t i = 0; i < ret_num; ++i) {
    try {
      auto* py_out_tensor = py::cast<framework::LoDTensor*>(result_tuple[i]);
      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
                              "Output tensor %d should not be nullptr", i);

      std::unique_ptr<framework::Variable> out_var(new framework::Variable());
      framework::LoDTensor* out_tensor =
          out_var->GetMutable<framework::LoDTensor>();
      out_tensor->ShareDataWith(*py_out_tensor);
      out_tensor->set_lod(py_out_tensor->lod());
      outs.emplace_back(std::move(out_var));
    } catch (py::cast_error&) {
      PADDLE_THROW("The %d-th output must be LoDTensor", i);
    }
  }
  return outs;
}

505 506
}  // namespace imperative
}  // namespace paddle