// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

const char* PyLayer::kFwdInp = "X";
const char* PyLayer::kFwdOut = "Out";

std::map<int, py::object> py_funcs_;

using framework::Variable;

namespace detail {

template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }
#endif

  // There is no BLAS support for CUDAPinnedPlace.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

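// Accumulates the gradient held in src into dst in place (dst += src). The
// add is dispatched through detail::TensorAddToFunctor to the BLAS AXPY
// routine of the given place; only float gradients are handled here.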
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): the loss_grad op passes a zero-sized grad for the
  // label; skip such empty gradients here as a workaround.
  if (src_tensor->numel() == 0) {
    return;
  }

  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                 src_tensor->numel());

  detail::TensorAddToFunctor<float> func(
      src_tensor->numel(), src_tensor->data<float>(),
      dst_tensor->mutable_data<float>(place));
  boost::apply_visitor(func, place);
}

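// Sums all pending sub-gradients recorded for `target` in bck_map into its
// grad variable, applying them in descending trace id order so the result is
// deterministic. Each temporary grad var is freed once it has been added.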
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
      bck_map->at(target);
  std::sort(
      current.second.begin(), current.second.end(),
      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
        return a.first > b.first;
      });
  for (auto& var_pair : current.second) {
    Variable* origin_grad = target->var_;
    Variable* grad_to_add = var_pair.second->var_;
    VLOG(2) << "add origin_grad: " << target->Name();
    VLOG(2) << "added grad: " << var_pair.second->Name()
            << " trace id is: " << var_pair.first;
    AddTo(grad_to_add, origin_grad, current.first);
    delete var_pair.second;
    var_pair.second = nullptr;
  }
}

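// Drives the backward pass: starting from the op that produced the root
// variable, walks the pre_ops_ graph and applies each op's gradients once
// every op that references it as a pre-op has been processed.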
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(3) << "start autograd";
    bck_map = new BackwardSumMap();
    grad_ref = new GradientRef();
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::map<std::string, std::vector<VarBase*>> input_grads =
          ready_op->ApplyGrad(bck_map, grad_ref, bck_stratedy);

      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
          if (ready_op->input_vars_[it->first][i]->IsStopGradient()) {
            continue;
          }
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
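  // Breadth-first walk over the pre_ops_ graph that counts, for every op, how
  // many ops referencing it as a pre-op must apply their gradients before it
  // becomes ready. When sorted_sum_gradient_ is set, it also fills grad_ref
  // with the number of grad ops that can produce each grad variable.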
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy) {
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              if (grad_ref->find(vb) == grad_ref->end()) {
                grad_ref->insert(std::make_pair(vb, 1));
              } else {
                // increase the ref count by one each time we find another
                // grad op that can generate this grad_var
                grad_ref->at(vb) += 1;
              }
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }

  BackwardSumMap* bck_map;
  GradientRef* grad_ref;
};

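// Copies this variable's tensor (and its LoD) to dst_place in a new VarBase.
// When blocking is true, waits on the destination device context, and on the
// source context as well if the places differ, before returning.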
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this once the unique_name generator moves to C++
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

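// Runs the backward computation for this op. PyLayer ops call back into the
// registered Python function; regular ops run every recorded grad op desc
// through a PreparedOp. Gradients are first written into freshly allocated
// temporary variables and then accumulated into the original grad variables,
// either directly via AddTo or, under the sorted-sum strategy, through
// bck_map and AddGradBySort once every producing grad op has run.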
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
                 "%s has no backward implementation", Type());

  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
    tmp_grad_outputs.resize(1);
    tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
        PyLayer::ApplyGrad(
            backward_id_,
            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
  } else {
    const size_t grad_op_count = grad_op_descs_.size();

    tmp_grad_outputs.resize(grad_op_count);
    for (size_t k = 0; k < grad_op_count; ++k) {
      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
      auto& grad_output_variable_map = grad_output_vars_[k];

      VLOG(3) << "apply grad op " << grad_op_desc->Type();

      // Allocate tmp grad output variable
      for (const auto& it : grad_output_variable_map) {
        auto& outputs = tmp_grad_outputs[k][it.first];
        outputs.reserve(it.second.size());
        for (size_t i = 0; i < it.second.size(); ++i) {
          VarBase* origin_grad_var_base = it.second[i];

          // Allocate a new variable
          VarBase* tmp_grad_var_base = new VarBase(
              string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
              origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
              place_, true, false);
          outputs.emplace_back(tmp_grad_var_base);
        }
      }

      // No need to do compile time infer shape here.
      // grad_op_desc_->InferShape(*block_);
      // grad_op_desc->InferVarType(block_);

      std::unique_ptr<framework::OperatorBase> opbase =
          framework::OpRegistry::CreateOp(*grad_op_desc);

      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
      if (info.infer_var_type_) {
        RuntimeInferVarTypeContext infer_var_type_ctx(
            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
        info.infer_var_type_(&infer_var_type_ctx);
      }

      framework::OperatorWithKernel* op_kernel =
          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
      PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

      // Run grad op
      framework::VariableValueMap grad_invars_map;
      framework::VariableValueMap grad_outvars_map;

      for (const auto& it : grad_input_vars_[k]) {
        auto& grad_invars = grad_invars_map[it.first];
        grad_invars.reserve(it.second.size());
        for (const VarBase* grad_inp : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                  grad_op_desc->Type(), grad_inp->Name());

          grad_invars.emplace_back(grad_inp->var_);
        }
      }

      for (const auto& it : tmp_grad_outputs[k]) {
        auto& grad_outvars = grad_outvars_map[it.first];
        grad_outvars.reserve(it.second.size());
        for (VarBase* grad_out : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                  grad_op_desc->Type(), grad_out->Name());

          grad_outvars.emplace_back(grad_out->var_);
        }
      }

      framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
      framework::Scope scope;
      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
      p.op.RuntimeInferShape(scope, place_, ctx);
      p.func(
          framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
    }
  }

  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
          VLOG(2) << origin_outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(2) << outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
#endif
          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i])
                .second.emplace_back(
                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
          }

          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
                         "Can't find %s in grad_reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                         "Backward error when calculating grad reference");
          if (grad_ref->at(origin_outputs[i]) > 1) {
            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i])--;
          } else {
            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i])--;
          }
        } else {
          framework::Variable* grad = outputs[i]->var_;
          framework::Variable* orig_grad = origin_outputs[i]->var_;
          VLOG(2) << "AddTo Called with orig_grad is: "
                  << origin_outputs[i]->name_ << " Grad to be added is "
                  << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          delete outputs[i];
        }
      }
    }
  }

  return input_vars_;
}

void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

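// Registers a Python callable that is invoked with this op after its
// gradients have been applied; `front` controls whether the hook is
// prepended to or appended to the hook list.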
void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  if (front) {
    backward_hooks_.insert(backward_hooks_.begin(), callable);
  } else {
    backward_hooks_.push_back(callable);
  }
}

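// Entry point of the imperative backward pass: fills this variable's grad
// tensor with ones and hands the traversal over to Autograd::RunBackward.
// Does nothing if the variable was not produced by any op.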
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;

  VLOG(3) << "start backward";
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
  py_funcs_[func_id] = py_func;
}

int PyLayer::NumFuncs() { return py_funcs_.size(); }

std::vector<framework::Variable*> PyLayer::Apply(
    int func_id, const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  return CallPythonFunc(py_funcs_[func_id], inputs);
}

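// Invokes the backward Python function registered under func_id and wraps
// each returned tensor in a new VarBase named after the gradient of the
// forward output.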
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
                                         const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);

  std::vector<VarBase*> outs;
  outs.reserve(rets.size());
  for (size_t i = 0U; i != rets.size(); ++i) {
    outs.emplace_back(new VarBase(
        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
                        i),
        rets[i], nullptr, true));
  }

  return outs;
}

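// Calls the Python callable with the tensors held by `ins` (under the GIL)
// and converts every LoDTensor in the returned tuple into a new Variable
// that shares the returned tensor's data and LoD.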
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
    const py::object& callable, const std::vector<VarBase*>& ins) {
  py::gil_scoped_acquire guard;
  py::tuple in_args(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
    in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
  }
  VLOG(3) << "pyfunc in " << py::len(in_args);

  // TODO(panyx0718): Who owns the returned LoDTensor?
  auto ret = callable(in_args);
  auto ret_tuple = py::cast<py::tuple>(ret);
  size_t ret_num = py::len(ret_tuple);
  std::vector<framework::Variable*> outs;
  outs.reserve(ret_num);
  VLOG(3) << "pyfunc out " << ret_num;
  for (size_t i = 0; i < ret_num; ++i) {
    try {
      auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
                              "Output tensor %d should not be nullptr", i);
      auto* var = new framework::Variable();
      auto* tensor = var->GetMutable<framework::LoDTensor>();
      tensor->ShareDataWith(*py_out_tensor);
      tensor->set_lod(py_out_tensor->lod());
      outs.emplace_back(var);
    } catch (py::cast_error&) {
      PADDLE_THROW("The %d-th output must be LoDTensor", i);
    }
  }
  return outs;
}

}  // namespace imperative
}  // namespace paddle