// tensor_wrapper.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * We still need TensorWrapper; it is designed to copy a tensor in
 * autograd mode.
 *
 * In autograd usage we need to pass autograd_meta to the backward
 * computation; however, adding too many autograd-related methods to the
 * tensor interface is not a good choice.
 *
 * In TensorWrapper we keep the autograd info needed for backward, but only
 * for input vars; for output vars we copy the tensor without grad info. **/

#pragma once
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/utils.h"

namespace egr {
class TensorWrapper {
 public:
  TensorWrapper() = default;
36
  explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
37 38
                         bool full_reserved = false,
                         bool no_need_buffer = false) {
39 40 41 42 43 44 45 46 47
    // set inplace_version_snapshot_ according to tensor's current inplace
    // version.
    if (tensor.impl() && phi::DenseTensor::classof(tensor.impl().get())) {
      phi::DenseTensor* dense_tensor =
          static_cast<phi::DenseTensor*>(tensor.impl().get());
      auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
      inplace_version_snapshot_ = inplace_version_counter.CurrentVersion();
    }

48 49 50 51 52 53 54 55 56 57 58 59 60
    /**
     * Normally, we should fully reserved all non-output or non-leaf fwd tensor
     * here. And for fwd output tensor, we should not reserve its autogradmeta,
     * to avoid recursive depends on GradNodeBase
     * **/
    full_reserved_ = full_reserved;
    if (full_reserved_) {
      VLOG(6) << "Fully reserved tensor: " << tensor.name();
      intermidiate_tensor_ = tensor;
      return;
    }

    // shallow copy tensor_impl here
61
    no_need_buffer_ = no_need_buffer;
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
    if (no_need_buffer) {
      if (phi::DenseTensor::classof(tensor.impl().get())) {
        // Only Copy Meta
        phi::DenseTensor* dense_tensor =
            static_cast<phi::DenseTensor*>(tensor.impl().get());
        auto tw_dense_tensor = std::make_shared<phi::DenseTensor>();
        tw_dense_tensor->set_meta(dense_tensor->meta());
        intermidiate_tensor_.set_impl(tw_dense_tensor);
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unrecognized tensor type for no_need_buffer feature"));
      }
    } else {
      intermidiate_tensor_.set_impl(tensor.impl());
    }

78
    intermidiate_tensor_.set_name(tensor.name() + "@Saved");
79 80 81 82 83 84 85

    // If an output is marked "intermedaite", we won't create
    // autograd_meta for it.
    // In that case, simply skip OutRankInfo Copy
    if (EagerUtils::nullable_autograd_meta(tensor)) {
      out_rank_info_ = EagerUtils::OutRankInfo(tensor);
    }
86 87
  }

88 89 90 91 92
  paddle::experimental::Tensor recover(
      const std::shared_ptr<GradNodeBase>& grad_node) {
    VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name()
            << " for wrapper";
    if (!intermidiate_tensor_.defined()) {
93
      VLOG(6) << "Return NULL tensor Here. ";
94
      return paddle::experimental::Tensor();
95 96 97 98
    }

    // if it's full_reserved just return the full copy of tensor
    if (full_reserved_) {
99
      check_inplace_version();
100 101 102 103 104 105 106 107
      return intermidiate_tensor_;
    } else {
      std::shared_ptr<GradNodeBase> new_grad_node = grad_node;
      auto p_ab_autograd_meta =
          std::make_shared<AutogradMeta>(Edge(new_grad_node, out_rank_info_));
      intermidiate_tensor_.set_autograd_meta(
          std::static_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
              p_ab_autograd_meta));
108
      check_inplace_version();
109 110 111 112
      return intermidiate_tensor_;
    }
  }

113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
  void check_inplace_version() {
    if (no_need_buffer_) {
      VLOG(6) << "There's no need to check inplace_version because "
                 "no_need_buffer_ is true.";
      return;
    }
    if (intermidiate_tensor_.impl() &&
        phi::DenseTensor::classof(intermidiate_tensor_.impl().get())) {
      phi::DenseTensor* dense_tensor =
          static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
      auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();

      uint32_t current_inplace_version =
          inplace_version_counter.CurrentVersion();
      PADDLE_ENFORCE_EQ(
          current_inplace_version, inplace_version_snapshot_,
          paddle::platform::errors::PermissionDenied(
              "Tensor '%s' used in gradient computation has been "
              "modified by an inplace operation. "
              "Its version is %d but the expected version is %d. "
              "Please fix your code to void calling an inplace operator "
              "after using the Tensor which will used in gradient "
              "computation.",
              intermidiate_tensor_.name(), current_inplace_version,
              inplace_version_snapshot_));
      VLOG(6) << " The inplace_version_snapshot_ of Tensor '"
              << intermidiate_tensor_.name() << "' is [ "
              << inplace_version_snapshot_ << " ]";
      VLOG(6) << " The current_inplace_version of Tensor '"
              << intermidiate_tensor_.name() << "' is [ "
              << current_inplace_version << " ]";
    }
  }

147 148
  void clear() { intermidiate_tensor_.reset(); }

149 150
 private:
  bool full_reserved_ = false;
151
  bool no_need_buffer_ = false;
152
  std::pair<size_t, size_t> out_rank_info_;
153
  paddle::experimental::Tensor intermidiate_tensor_;
154
  uint32_t inplace_version_snapshot_ = 0;
155 156
};
}  // namespace egr