// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <map>
#include <random>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace imperative {

const char* PyLayer::kFwdInp = "X";
const char* PyLayer::kFwdOut = "Out";

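// Python callables registered through PyLayer::RegisterFunc, keyed by func_id.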
std::map<int, py::object> py_funcs_;

using framework::Variable;

namespace detail {

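// Place visitor that accumulates y += x element-wise over numel_ elements,
// dispatching to the BLAS AXPY kernel of the visited device.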
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
 public:
  TensorAddToFunctor(int64_t numel, const T* x, T* y)
      : numel_(numel), x_(x), y_(y) {}

  void operator()(const platform::CPUPlace& place) {
    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }

#ifdef PADDLE_WITH_CUDA
  void operator()(const platform::CUDAPlace& place) {
    platform::CUDADeviceContext* ctx =
        dynamic_cast<platform::CUDADeviceContext*>(
            platform::DeviceContextPool::Instance().Get(place));
    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
    blas.AXPY(numel_, 1., x_, y_);
  }
#else
  void operator()(const platform::CUDAPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }
#endif

  // There is no BLAS implementation for CUDAPinnedPlace, so gradient merge
  // is not supported there either.
  void operator()(const platform::CUDAPinnedPlace& place) {
    PADDLE_THROW("Gradient merge is not supported on place %s", place);
  }

 private:
  int64_t numel_;
  const T* x_;
  T* y_;
};

}  // namespace detail

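// Accumulates the gradient held in src into dst (dst += src) on the given
// place via TensorAddToFunctor. Both variables must hold float LoDTensors
// with the same number of elements; an empty src tensor is skipped.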
void AddTo(Variable* src, Variable* dst, platform::Place place) {
  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();

  // FIXME(minqiyang): the loss_grad op passes an empty (zero-element) grad
  // for the label; skip it here as a temporary workaround.
  if (src_tensor->numel() == 0) {
    return;
  }

  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                 src_tensor->numel());

  detail::TensorAddToFunctor<float> func(
      src_tensor->numel(), src_tensor->data<float>(),
      dst_tensor->mutable_data<float>(place));
  boost::apply_visitor(func, place);
}

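// Accumulates every pending sub-gradient recorded for target in bck_map into
// the target's grad variable, in descending trace-id order, so that the
// summation order stays deterministic under the sorted-sum-gradient strategy.
// The temporary sub-gradient VarBases are freed after they are added.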
void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                 "Can't find %s in backward grad map", target->Name());
  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
      bck_map->at(target);
  std::sort(
      current.second.begin(), current.second.end(),
      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
        return a.first > b.first;
      });
  for (auto& var_pair : current.second) {
    Variable* origin_grad = target->var_.get();
    Variable* grad_to_add = var_pair.second->var_.get();
    VLOG(2) << "add origin_grad: " << target->Name();
    VLOG(2) << "added grad: " << var_pair.second->Name()
            << " trace id is: " << var_pair.first;
    AddTo(grad_to_add, origin_grad, current.first);
    delete var_pair.second;
    var_pair.second = nullptr;
  }
}

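// Drives the backward pass over the dynamically recorded graph: it first
// counts how many pending successors each OpBase has, then applies grad ops
// in topological order and fires their backward hooks.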
class Autograd {
 public:
  Autograd() {}

  void RunBackward(VarBase* var, const detail::BackwardStrategy& bck_stratedy) {
    if (var->IsStopGradient()) {
      return;
    }
    VLOG(3) << "start autograd";
    BackwardSumMap bck_map;
    GradientRef grad_ref;
    std::deque<OpBase*> ready;
    ready.push_back(var->PreOp());

    std::map<OpBase*, int> dep_counts =
        ComputeDepCounts(var->PreOp(), bck_stratedy, &grad_ref);

    while (!ready.empty()) {
      OpBase* ready_op = ready.front();
      ready.pop_front();
      std::map<std::string, std::vector<VarBase*>> input_grads =
          ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);

      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
        const std::vector<VarBase*>& ingrads = it->second;
        for (size_t i = 0; i < ingrads.size(); ++i) {
          if (!ingrads[i]) continue;
          auto p = ready_op->input_vars_[it->first][i];

          if (p->IsStopGradient()) continue;
          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
          if (!pre_op) continue;

          dep_counts[pre_op] -= 1;
          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
          bool pre_op_ready = dep_counts[pre_op] == 0;
          if (pre_op_ready) {
            ready.push_back(pre_op);
          }
        }
      }

      ready_op->InvokeBackwardHooks();
    }
  }

 private:
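  // BFS over pre_ops_ starting from op, counting how many direct successors
  // each reachable OpBase has. When the sorted-sum-gradient strategy is on,
  // it also counts how many times each grad output VarBase is referenced and
  // stores the result in *grad_ref.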
  std::map<OpBase*, int> ComputeDepCounts(
      OpBase* op, const detail::BackwardStrategy& bck_stratedy,
      GradientRef* grad_ref) {
    if (bck_stratedy.sorted_sum_gradient_) {
      PADDLE_ENFORCE_NOT_NULL(grad_ref,
                              "grad_ref should not be null when "
                              "using sorted grad backward strategy");
    }
    std::map<OpBase*, int> ret;

    std::deque<OpBase*> queue;
    queue.push_back(op);
    std::unordered_set<OpBase*> visited;
    visited.insert(op);
    while (!queue.empty()) {
      OpBase* candidate = queue.front();
      queue.pop_front();
      if (bck_stratedy.sorted_sum_gradient_) {
        for (const auto& map : candidate->grad_output_vars_) {
          for (const auto& it : map) {
            for (const auto& vb : it.second) {
              ++(*grad_ref)[vb];
196 197 198 199
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
            visited.insert(pre_op);
            queue.push_back(pre_op);
          }
          ret[pre_op] += 1;
        }
      }
    }
    return ret;
  }
};

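// Copies this variable's LoDTensor to dst_place and returns it wrapped in a
// new VarBase. When blocking is true, waits on the destination (and, if it
// differs, the source) device context before returning.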
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE(var_->IsInitialized(),
                 "Variable must be initialized when getting numpy tensor");

  // TODO(minqiyang): change this after moving the unique_name generator to C++
  const framework::LoDTensor& self_tensor = var_->Get<framework::LoDTensor>();
  std::unique_ptr<VarBase> new_var(new VarBase(
      "Itmp", self_tensor.type(), self_tensor.dims(), dst_place, true, false));
  framework::LoDTensor* tensor =
      new_var->var_->GetMutable<framework::LoDTensor>();
  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());

  const auto& src_tensor = var_->Get<framework::LoDTensor>();
  framework::TensorCopy(src_tensor, dst_place, tensor);
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
    auto src_place = src_tensor.place();
    if (!(src_place == dst_place)) {
      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
    }
  }

  if (platform::is_gpu_place(dst_place)) {
    VLOG(3) << "copy tensor " << Name() << " from gpu";
  }

  return new_var;
}

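// Returns a mutable reference to the LoDTensor holding this variable's
// gradient.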
framework::LoDTensor& VarBase::GradValue() {
  VLOG(3) << "get var grad " << Name();
  PADDLE_ENFORCE_NOT_NULL(grads_,
                          "Could not get grad value from no grad variable");
  return *(grads_->var_->GetMutable<framework::LoDTensor>());
}

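// Runs the backward computation of this op. A PyLayer op (backward_id_ > 0)
// invokes its registered Python grad function; a regular op instantiates and
// runs every recorded grad_op_desc with temporary output variables. The
// temporaries are then merged into the original grad variables, either
// immediately via AddTo or, under the sorted-sum-gradient strategy, deferred
// and accumulated in trace-id order via AddGradBySort.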
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
    BackwardSumMap* bck_map, GradientRef* grad_ref,
    const detail::BackwardStrategy& bck_stratedy) {
  PADDLE_ENFORCE(!grad_op_descs_.empty() || backward_id_ > 0,
                 "%s has no backward implementation", Type());
  VLOG(3) << "apply op grad: " << Type();
  std::vector<VarBasePtrMap> tmp_grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
    tmp_grad_outputs.resize(1);
    tmp_grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
        PyLayer::ApplyGrad(
            backward_id_,
            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
  } else {
    const size_t grad_op_count = grad_op_descs_.size();

    tmp_grad_outputs.resize(grad_op_count);
    for (size_t k = 0; k < grad_op_count; ++k) {
      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
      platform::RecordEvent record_event(grad_op_desc->Type());
      auto& grad_output_variable_map = grad_output_vars_[k];
      VLOG(3) << "apply grad op " << grad_op_desc->Type();

      // Allocate tmp grad output variable
      for (const auto& it : grad_output_variable_map) {
        auto& outputs = tmp_grad_outputs[k][it.first];
        outputs.reserve(it.second.size());
        for (size_t i = 0; i < it.second.size(); ++i) {
          VarBase* origin_grad_var_base = it.second[i];

          // Allocate a new variable
          VarBase* tmp_grad_var_base = new VarBase(
              string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
              origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
              place_, true, false);
          outputs.emplace_back(tmp_grad_var_base);
        }
      }

      // No need to do compile time infer shape here.
      // grad_op_desc_->InferShape(*block_);
      // grad_op_desc->InferVarType(block_);

      std::unique_ptr<framework::OperatorBase> opbase =
          framework::OpRegistry::CreateOp(*grad_op_desc);

      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
      if (info.infer_var_type_) {
        RuntimeInferVarTypeContext infer_var_type_ctx(
            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
        info.infer_var_type_(&infer_var_type_ctx);
      }

      framework::OperatorWithKernel* op_kernel =
          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
      PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

      // Run grad op
      framework::VariableValueMap grad_invars_map;
      framework::VariableValueMap grad_outvars_map;

      for (const auto& it : grad_input_vars_[k]) {
        auto& grad_invars = grad_invars_map[it.first];
        grad_invars.reserve(it.second.size());
        for (const VarBase* grad_inp : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                  grad_op_desc->Type(), grad_inp->Name());

          grad_invars.emplace_back(grad_inp->var_.get());
        }
      }

      for (const auto& it : tmp_grad_outputs[k]) {
        auto& grad_outvars = grad_outvars_map[it.first];
        grad_outvars.reserve(it.second.size());
        for (VarBase* grad_out : it.second) {
          PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                  grad_op_desc->Type(), grad_out->Name());

          grad_outvars.emplace_back(grad_out->var_.get());
        }
      }

      framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
      framework::Scope scope;
      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
      p.op.RuntimeInferShape(scope, place_, ctx);
      p.func(
          framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr));
    }
  }

  platform::RecordEvent record_event("merge_grads");
  // Add tmp grad outputs to original grad vars
  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
    for (const auto& it : grad_output_vars_[k]) {
      auto& outputs = tmp_grad_outputs[k][it.first];
      const auto& origin_outputs = it.second;
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());

      for (size_t i = 0; i < outputs.size(); ++i) {
        // track outputs used by sum
        if (bck_stratedy.sorted_sum_gradient_) {
#ifndef PADDLE_WITH_CUDA
          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
          VLOG(2) << origin_outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
          VLOG(2) << outputs[i]
                         ->var_->GetMutable<framework::LoDTensor>()
                         ->data<float>()[0];
#endif
          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
            bck_map->at(origin_outputs[i])
                .second.emplace_back(
                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
          } else {
            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
          }

          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
                         "Can't find %s in grad_reference count map",
                         origin_outputs[i]->Name());
          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                         "Backward error when calculating grad reference");
          if (grad_ref->at(origin_outputs[i]) > 1) {
            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
            grad_ref->at(origin_outputs[i])--;
          } else {
            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
            grad_ref->at(origin_outputs[i])--;
          }
        } else {
          framework::Variable* grad = outputs[i]->var_.get();
          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
          VLOG(2) << "AddTo Called with orig_grad is: "
                  << origin_outputs[i]->name_ << " Grad to be added is "
                  << outputs[i]->name_;
          AddTo(grad, orig_grad, place_);
          delete outputs[i];
        }
      }
    }
  }

  return input_vars_;
}

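// Calls every Python callable registered through RegisterBackwardHooks,
// passing this op as the argument.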
void OpBase::InvokeBackwardHooks() {
  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();

  // call backward hooks
  for (py::object& callable : backward_hooks_) {
    callable(this);
  }
}

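// Stores a Python callable that will be invoked after this op's gradients
// have been applied during the backward pass.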
void OpBase::RegisterBackwardHooks(const py::object& callable) {
  VLOG(3) << "Register backward hooks " << trace_id_;

  // TODO(minqiyang): check the callable format
  backward_hooks_.push_back(callable);
}

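// Entry point of the imperative backward pass for this variable: seeds its
// gradient tensor with ones and lets Autograd traverse the recorded graph
// starting from the op that produced it.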
void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
  if (!pre_op_) return;
  platform::RecordEvent record_event("Imperative Backward");
  VLOG(3) << "start backward";
  auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
  operators::math::set_constant(
      *(platform::DeviceContextPool::Instance().Get(
          var_->GetMutable<framework::LoDTensor>()->place())),
      grads_t, 1.0);

  PADDLE_ENFORCE(
      grads_ ==
      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
  Autograd().RunBackward(this, bck_stratedy);
}

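// Registers a Python callable under func_id so that PyLayer::Apply and
// PyLayer::ApplyGrad can look it up later.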
void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
  py_funcs_[func_id] = py_func;
}

int PyLayer::NumFuncs() { return py_funcs_.size(); }

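// Runs the Python function registered under func_id on the inputs' tensors
// and returns the resulting Variables.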
std::vector<std::unique_ptr<framework::Variable>> PyLayer::Apply(
    int func_id, const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  return CallPythonFunc(py_funcs_[func_id], inputs);
}

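// Runs the Python function registered under func_id and wraps each returned
// tensor in a new VarBase named after the gradient of the forward output.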
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
                                         const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);

  std::vector<VarBase*> outs;
  outs.reserve(rets.size());
  for (size_t i = 0U; i != rets.size(); ++i) {
    outs.emplace_back(new VarBase(
        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
                        i),
        std::move(rets[i]), nullptr, true));
  }

  return outs;
}

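// Bridges into Python: converts the input LoDTensors to Python objects, calls
// callable, and wraps every returned LoDTensor in a new framework::Variable,
// sharing the tensor data and preserving its LoD.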
std::vector<std::unique_ptr<framework::Variable>> PyLayer::CallPythonFunc(
    const py::object& callable, const std::vector<VarBase*>& ins) {
  py::gil_scoped_acquire guard;
  py::tuple in_args(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
    in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
  }
  VLOG(3) << "pyfunc in " << py::len(in_args);

  // TODO(panyx0718): Who owns the returned LoDTensor.
  auto ret = callable(in_args);
  auto ret_tuple = py::cast<py::tuple>(ret);
  size_t ret_num = py::len(ret_tuple);
  std::vector<std::unique_ptr<framework::Variable>> outs;
  outs.reserve(ret_num);
  VLOG(3) << "pyfunc out " << ret_num;
  for (size_t i = 0; i < ret_num; ++i) {
    try {
      auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
                              "Output tensor %d should not be nullptr", i);
      auto var =
          std::unique_ptr<framework::Variable>(new framework::Variable());
      auto* tensor = var->GetMutable<framework::LoDTensor>();
      tensor->ShareDataWith(*py_out_tensor);
      tensor->set_lod(py_out_tensor->lod());
      outs.emplace_back(std::move(var));
    } catch (py::cast_error&) {
      PADDLE_THROW("The %d-th output must be LoDTensor", i);
    }
  }
  return outs;
}

}  // namespace imperative
}  // namespace paddle