tensor_wrapper.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * We still need TensorWrapper; it is designed to copy a tensor for use in
 * autograd mode.
 *
 * In autograd usage we need to pass autograd_meta to the backward
 * computation, but adding too many autograd-related methods to the tensor
 * interface is not a good choice.
 *
 * TensorWrapper keeps the autograd info needed for backward, but only for
 * input vars; for output vars it only copies the tensor, without grad
 * info. **/

#pragma once
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/utils.h"

namespace egr {
class TensorWrapper {
 public:
  TensorWrapper() = default;
  explicit TensorWrapper(const egr::EagerTensor& tensor,
                         bool full_reserved = false) {
    /**
     * Normally, we should fully reserve all non-output or non-leaf forward
     * tensors here. For forward output tensors, we should not reserve their
     * autograd meta, to avoid a recursive dependency on GradNodeBase.
     * **/
    full_reserved_ = full_reserved;
    if (full_reserved_) {
      VLOG(6) << "Fully reserved tensor: " << tensor.name();
      intermidiate_tensor_ = tensor;
      return;
    }

    // shallow copy tensor_impl here
    intermidiate_tensor_.set_impl(tensor.impl());
    intermidiate_tensor_.ResetVar(tensor.Var());
    intermidiate_tensor_.set_name(tensor.name() + "@Saved");
    PADDLE_ENFORCE_NOT_NULL(
        EagerUtils::unsafe_autograd_meta(tensor),
        paddle::platform::errors::Fatal(
            "Full reserved Tensor should not have null autograd meta, since "
            "tensor_wrapper is used to build backward info. There is no way "
            "for us to build it with null autograd_meta."));
    // copy output_rank
    out_rank_info_ = EagerUtils::OutRankInfo(tensor);
  }

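  // Rebuild the saved tensor for backward. If the wrapper was fully
  // reserved, the original tensor (including its autograd meta) is returned
  // as-is; otherwise a fresh AutogradMeta carrying an Edge to grad_node and
  // the saved out_rank_info_ is attached before returning.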
  egr::EagerTensor recover(const std::shared_ptr<GradNodeBase>& grad_node) {
    VLOG(6) << "Recover tensor for wrapper";
    if ((!intermidiate_tensor_.defined()) &&
        (!intermidiate_tensor_.Var().IsInitialized())) {
      VLOG(6) << "Return NULL tensor Here. ";
      return egr::EagerTensor();
    }

    // If fully reserved, just return the saved tensor as-is.
    if (full_reserved_) {
      return intermidiate_tensor_;
    } else {
      std::shared_ptr<GradNodeBase> new_grad_node = grad_node;
      auto p_ab_autograd_meta =
          std::make_shared<AutogradMeta>(Edge(new_grad_node, out_rank_info_));
      intermidiate_tensor_.set_autograd_meta(
          std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
              p_ab_autograd_meta));
      return intermidiate_tensor_;
    }
  }

 private:
  // Whether the wrapped tensor was fully reserved (autograd meta included).
  bool full_reserved_ = false;
  // Output rank info copied from the wrapped tensor's autograd meta.
  std::pair<size_t, size_t> out_rank_info_;
  // The saved tensor (a shallow copy of impl/Var when not fully reserved).
  egr::EagerTensor intermidiate_tensor_;
};
}  // namespace egr
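
// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the original header): how a grad node
// might save a forward input with TensorWrapper and recover it for backward.
// The struct and method names below are hypothetical; only TensorWrapper's
// constructor and recover() come from this file.

struct ExampleSavedInput {
  // Forward pass: wrap the input tensor. full_reserved = true keeps the
  // tensor's autograd meta, which is the desired behavior for forward inputs.
  void Save(const egr::EagerTensor& x) {
    wrapper_ = egr::TensorWrapper(x, /*full_reserved=*/true);
  }

  // Backward pass: rebuild the saved tensor. For a wrapper built with
  // full_reserved = false, recover() attaches a new AutogradMeta whose Edge
  // points at grad_node together with the saved output rank info.
  egr::EagerTensor Recover(
      const std::shared_ptr<egr::GradNodeBase>& grad_node) {
    return wrapper_.recover(grad_node);
  }

 private:
  egr::TensorWrapper wrapper_;
};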