// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

const char* PyLayer::kFwdInp = "X";
const char* PyLayer::kFwdOut = "Out";

std::map<int, py::object> py_funcs_;

using framework::Variable;

namespace detail {

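// Visitor that accumulates the source buffer x into the destination buffer y
// with a BLAS AXPY call, dispatched on the place the gradient lives on (CPU
// or CUDA). CUDA-pinned memory is rejected because no BLAS handle exists for
// it.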
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Gradient merge is not supported in place %s", place);
  }
#endif

  // There is no BLAS implementation for CUDAPinnedPlace.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient merge is not supported in place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

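// Accumulates the gradient tensor held by src into dst on the given place.
// Both variables must hold LoDTensors with the same number of elements; only
// float gradients are handled here.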
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): the loss_grad op passes a zero-sized grad for the label;
  // skip the accumulation here as a temporary workaround.
  if (src_tensor->numel() == 0) {
    return;
  }

  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                 src_tensor->numel());

  detail::TensorAddToFunctor<float> func(
      src_tensor->numel(), src_tensor->data<float>(),
      dst_tensor->mutable_data<float>(place));
  boost::apply_visitor(func, place);
}

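// Flushes all pending sub-gradients recorded for target in bck_map, adding
// them into the target's gradient in descending trace-id order so that the
// accumulation order is deterministic.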
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
      bck_map->at(target);
  std::sort(
      current.second.begin(), current.second.end(),
      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
        return a.first > b.first;
      });
  for (auto& var_pair : current.second) {
    Variable* origin_grad = target->var_;
    Variable* grad_to_add = var_pair.second->var_;
    VLOG(2) << "add origin_grad: " << target->Name();
    VLOG(2) << "added grad: " << var_pair.second->Name()
            << " trace id is: " << var_pair.first;
    AddTo(grad_to_add, origin_grad, current.first);
    delete grad_to_add;
    var_pair.second = nullptr;
  }
}

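// Drives the backward pass: starting from a variable's creator op, it walks
// pre_ops_ in reverse, running each op's ApplyGrad once every op that depends
// on it has been processed.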
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(3) << "start autograd";
    bck_map = new BackwardSumMap();
    grad_ref = new GradientRef();
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::map<std::string, std::vector<VarBase*>> input_grads =
          ready_op->ApplyGrad(bck_map, grad_ref, bck_stratedy);

      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
          if (ready_op->input_vars_[it->first][i]->IsStopGradient()) {
            continue;
          }
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
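  // BFS over pre_ops_ that counts, for every op, how many downstream ops must
  // be processed before it becomes ready. When sorted_sum_gradient_ is set it
  // also records in grad_ref how many grad ops can write each grad variable.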
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy) {
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              if (grad_ref->find(vb) == grad_ref->end()) {
                grad_ref->insert(std::make_pair(vb, 1));
              } else {
                // increase the ref count by 1 for every grad op that can
                // generate this grad_var
                grad_ref->at(vb) += 1;
              }
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }

  BackwardSumMap* bck_map;
  GradientRef* grad_ref;
};

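// Copies this variable's tensor (and its LoD) into a freshly created VarBase
// on dst_place. When blocking is true the copy is synchronous and waits on
// the destination device context.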
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after move unique_name generator to CXX
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  if (blocking) {
    platform::DeviceContext* dev_ctx =
        platform::DeviceContextPool::Instance().Get(dst_place);

    framework::TensorCopySync(var_->Get<framework::LoDTensor>(), dst_place,
                              tensor);

    dev_ctx->Wait();
  } else {
    framework::TensorCopy(var_->Get<framework::LoDTensor>(), dst_place, tensor);
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

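// Returns the LoDTensor holding this variable's gradient; the variable must
// have an associated grad VarBase.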
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

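// Runs the grad op(s) registered for this op: temporary output grad variables
// are allocated, each grad kernel is executed, and the temporaries are merged
// into the original grad variables, either immediately via AddTo or, when
// sorted_sum_gradient_ is set, deferred and summed in sorted trace-id order.
// PyLayer ops call their registered Python backward function instead.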
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
                 "%s has no backward implementation", Type());

  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
    tmp_grad_outputs.resize(1);
    tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
        PyLayer::ApplyGrad(
            backward_id_,
            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
  } else {
    const size_t grad_op_count = grad_op_descs_.size();

    tmp_grad_outputs.resize(grad_op_count);
    for (size_t k = 0; k < grad_op_count; ++k) {
      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
      auto& grad_output_variable_map = grad_output_vars_[k];

      VLOG(3) << "apply grad op " << grad_op_desc->Type();

      // Allocate tmp grad output variable
      for (const auto& it : grad_output_variable_map) {
        auto& outputs = tmp_grad_outputs[k][it.first];
        outputs.reserve(it.second.size());
        for (size_t i = 0; i < it.second.size(); ++i) {
          VarBase* origin_grad_var_base = it.second[i];

          // Allocate a new variable
          VarBase* tmp_grad_var_base = new VarBase(
              string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
              origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
              place_, true, false);
          outputs.emplace_back(tmp_grad_var_base);
        }
      }

      // No need to do compile time infer shape here.
      // grad_op_desc_->InferShape(*block_);
      // grad_op_desc->InferVarType(block_);

      std::unique_ptr<framework::OperatorBase> opbase =
          framework::OpRegistry::CreateOp(*grad_op_desc);

      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
      if (info.infer_var_type_) {
        RuntimeInferVarTypeContext infer_var_type_ctx(
            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
        info.infer_var_type_(&infer_var_type_ctx);
      }

      framework::OperatorWithKernel* op_kernel =
          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
      PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

      // Run grad op
      framework::VariableValueMap grad_invars_map;
      framework::VariableValueMap grad_outvars_map;

      for (const auto& it : grad_input_vars_[k]) {
        auto& grad_invars = grad_invars_map[it.first];
        grad_invars.reserve(it.second.size());
        for (const VarBase* grad_inp : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                  grad_op_desc->Type(), grad_inp->Name());

          grad_invars.emplace_back(grad_inp->var_);
        }
      }

      for (const auto& it : tmp_grad_outputs[k]) {
        auto& grad_outvars = grad_outvars_map[it.first];
        grad_outvars.reserve(it.second.size());
        for (VarBase* grad_out : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                  grad_op_desc->Type(), grad_out->Name());

          grad_outvars.emplace_back(grad_out->var_);
        }
      }

      framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
      framework::Scope scope;
      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
      p.op.RuntimeInferShape(scope, place_, ctx);
      p.func(
          framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
    }
  }

  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
          VLOG(2) << origin_outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(2) << outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
#endif
          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i])
                .second.emplace_back(
                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
          }

          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
                         "Can't find %s in the grad_reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                         "Backward error when calculating grad reference");
          if (grad_ref->at(origin_outputs[i]) > 1) {
            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i])--;
          } else {
            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i])--;
          }
        } else {
          framework::Variable* grad = outputs[i]->var_;
          framework::Variable* orig_grad = origin_outputs[i]->var_;
          VLOG(2) << "AddTo Called with orig_grad is: "
                  << origin_outputs[i]->name_ << " Grad to be added is "
                  << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          delete grad;
        }
      }
    }
  }

  return input_vars_;
}

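// Calls every Python hook registered on this op once its gradients have been
// applied.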
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

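// Registers a Python callable to run after this op's backward pass; front
// decides whether it is inserted before or after the existing hooks.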
void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  if (front) {
    backward_hooks_.insert(backward_hooks_.begin(), callable);
  } else {
    backward_hooks_.push_back(callable);
  }
}

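// Entry point of the imperative backward pass for this variable: its grad
// tensor is filled with ones and Autograd traverses the graph from its
// creator op.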
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;

  VLOG(3) << "start backward";
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
  py_funcs_[func_id] = py_func;
}

int PyLayer::NumFuncs() { return py_funcs_.size(); }

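// Runs the Python forward function registered under func_id on the given
// inputs and returns the produced variables.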
std::vector<framework::Variable*> PyLayer::Apply(
    int func_id, const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  return CallPythonFunc(py_funcs_[func_id], inputs);
}

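// Runs the Python backward function registered under func_id and wraps each
// returned tensor in a new VarBase named after the forward output's grad var.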
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
                                         const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);

  std::vector<VarBase*> outs;
  outs.reserve(rets.size());
  for (size_t i = 0U; i != rets.size(); ++i) {
    outs.emplace_back(new VarBase(
        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
                        i),
        rets[i], nullptr, true));
  }

  return outs;
}

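// Invokes the Python callable under the GIL with the input tensors and copies
// each returned LoDTensor into a fresh Variable that shares its data.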
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
    const py::object& callable, const std::vector<VarBase*>& ins) {
  py::gil_scoped_acquire guard;
  py::tuple in_args(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
    in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
  }
  VLOG(3) << "pyfunc in " << py::len(in_args);

  // TODO(panyx0718): Who owns the returned LoDTensor.
  auto ret = callable(in_args);
  auto ret_tuple = py::cast<py::tuple>(ret);
  size_t ret_num = py::len(ret_tuple);
  std::vector<framework::Variable*> outs;
  outs.reserve(ret_num);
  VLOG(3) << "pyfunc out " << ret_num;
  for (size_t i = 0; i < ret_num; ++i) {
    try {
      auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
                              "Output tensor %d should not be nullptr", i);
      auto* var = new framework::Variable();
      auto* tensor = var->GetMutable<framework::LoDTensor>();
      tensor->ShareDataWith(*py_out_tensor);
      tensor->set_lod(py_out_tensor->lod());
      outs.emplace_back(var);
    } catch (py::cast_error&) {
      PADDLE_THROW("The %d-th output must be LoDTensor", i);
    }
  }
  return outs;
}

}  // namespace imperative
}  // namespace paddle