// tensor_wrapper.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * TensorWrapper is still needed; it is designed to copy a tensor
 * in autograd mode.
 *
 * In autograd usage we need to pass autograd_meta to the backward
 * computation, but adding too many autograd-related methods to the
 * tensor interface is not a good choice.
 *
 * TensorWrapper keeps the full autograd info needed by backward only
 * for input vars; for output vars it only copies the autograd meta,
 * without its grad node.
 **/

#pragma once
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/utils.h"

namespace egr {
class TensorWrapper {
 public:
  TensorWrapper() = default;
36
  explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
37 38
                         bool full_reserved = false,
                         bool no_need_buffer = false) {
39 40 41 42 43 44 45 46 47
    // set inplace_version_snapshot_ according to tensor's current inplace
    // version.
    if (tensor.impl() && phi::DenseTensor::classof(tensor.impl().get())) {
      phi::DenseTensor* dense_tensor =
          static_cast<phi::DenseTensor*>(tensor.impl().get());
      auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
      inplace_version_snapshot_ = inplace_version_counter.CurrentVersion();
    }

48 49 50 51 52 53
    /**
     * Normally, we should fully reserved all non-output or non-leaf fwd tensor
     * here. And for fwd output tensor, we should not reserve its autogradmeta,
     * to avoid recursive depends on GradNodeBase
     * **/
    full_reserved_ = full_reserved;
54
    no_need_buffer_ = no_need_buffer;
55 56 57
    if (full_reserved_) {
      VLOG(6) << "Fully reserved tensor: " << tensor.name();
      intermidiate_tensor_ = tensor;
58 59 60 61 62 63 64 65 66 67 68 69 70 71
      if (no_need_buffer_) {
        if (phi::DenseTensor::classof(tensor.impl().get())) {
          // Only Copy Meta
          phi::DenseTensor* dense_tensor =
              static_cast<phi::DenseTensor*>(tensor.impl().get());
          auto tw_dense_tensor =
              std::make_shared<phi::DenseTensor>(*dense_tensor);
          tw_dense_tensor->clear();
          intermidiate_tensor_.set_impl(tw_dense_tensor);
        } else {
          PADDLE_THROW(paddle::platform::errors::Fatal(
              "Unrecognized tensor type for no_need_buffer feature"));
        }
      }
72 73 74 75
      return;
    }

    // shallow copy tensor_impl here
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    if (no_need_buffer) {
      if (phi::DenseTensor::classof(tensor.impl().get())) {
        // Only Copy Meta
        phi::DenseTensor* dense_tensor =
            static_cast<phi::DenseTensor*>(tensor.impl().get());
        auto tw_dense_tensor = std::make_shared<phi::DenseTensor>();
        tw_dense_tensor->set_meta(dense_tensor->meta());
        intermidiate_tensor_.set_impl(tw_dense_tensor);
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unrecognized tensor type for no_need_buffer feature"));
      }
    } else {
      intermidiate_tensor_.set_impl(tensor.impl());
    }

92
    intermidiate_tensor_.set_name(tensor.name() + "@Saved");
93

94 95
    auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (tensor_autograd_meta) {
96 97 98
      auto autograd_meta =
          std::make_shared<AutogradMeta>(*tensor_autograd_meta);
      autograd_meta->ResetGradNode();
99 100
      intermidiate_tensor_.set_autograd_meta(autograd_meta);
      weak_grad_node_ = tensor_autograd_meta->GetMutableGradNode();
101
    }
102 103
  }

104
  paddle::experimental::Tensor recover() {
105 106 107
    VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name()
            << " for wrapper";
    if (!intermidiate_tensor_.defined()) {
108
      VLOG(6) << "Return NULL tensor Here. ";
109
      return paddle::experimental::Tensor();
110 111
    }

112
    check_inplace_version();
113

114
    // if it's full_reserved just return the full copy of tensor
115 116 117 118 119
    if (full_reserved_) {
      return intermidiate_tensor_;
    } else {
      paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;

120 121 122 123 124 125 126 127 128
      std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
      if (new_grad_node) {
        VLOG(3) << "Recovered TensorWrapper with GradNode "
                << new_grad_node->name() << " addr: " << new_grad_node.get();
      } else {
        VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
      }
      auto* intermediate_autograd_meta =
          EagerUtils::unsafe_autograd_meta(intermidiate_tensor_);
129 130 131 132 133 134 135
      auto p_ab_autograd_meta =
          std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
      if (new_grad_node) {
        p_ab_autograd_meta->SetGradNode(new_grad_node);
      }
      recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
      return recovered_tensor;
136 137 138
    }
  }

139 140 141 142 143 144 145 146 147 148 149 150
  void check_inplace_version() {
    if (no_need_buffer_) {
      VLOG(6) << "There's no need to check inplace_version because "
                 "no_need_buffer_ is true.";
      return;
    }
    if (intermidiate_tensor_.impl() &&
        phi::DenseTensor::classof(intermidiate_tensor_.impl().get())) {
      phi::DenseTensor* dense_tensor =
          static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
      auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();

151 152
      uint32_t wrapper_version_snapshot = inplace_version_snapshot_;
      uint32_t tensor_version = inplace_version_counter.CurrentVersion();
153
      PADDLE_ENFORCE_EQ(
154
          tensor_version, wrapper_version_snapshot,
155 156 157 158 159 160 161
          paddle::platform::errors::PermissionDenied(
              "Tensor '%s' used in gradient computation has been "
              "modified by an inplace operation. "
              "Its version is %d but the expected version is %d. "
              "Please fix your code to void calling an inplace operator "
              "after using the Tensor which will used in gradient "
              "computation.",
162 163 164
              intermidiate_tensor_.name(), tensor_version,
              wrapper_version_snapshot));
      VLOG(6) << " The wrapper_version_snapshot of Tensor '"
165
              << intermidiate_tensor_.name() << "' is [ "
166 167 168 169
              << wrapper_version_snapshot << " ]";
      VLOG(6) << " The tensor_version of Tensor '"
              << intermidiate_tensor_.name() << "' is [ " << tensor_version
              << " ]";
170 171 172
    }
  }

173 174
  void clear() { intermidiate_tensor_.reset(); }

175 176
 private:
  bool full_reserved_ = false;
177
  bool no_need_buffer_ = false;
178
  paddle::experimental::Tensor intermidiate_tensor_;
179
  std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
180
  uint32_t inplace_version_snapshot_ = 0;
181 182
};
}  // namespace egr