// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// framework deps
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// Phi deps
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/macros.h"

namespace egr {

/**
 * VariableCompatTensor class is used by Eager mode for now. It's painful to
 * do this in Eager Mode, the better choice is to design the special Tensor
 * directly in phi and use it in paddle::Tensor.
 * However, we have some special operators, and they use special input variable
 * type, such as vector<string>, unordered_map<wstring, int>, these type cannot
 * cover by DenseTensor or SparseTensor. So, we have to provide a compatible
 * Tensor type like variable to support these special input type. We should
 * remove this as soon as we finish the ResourceTensor in phi.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support method declared in framework::Variable and
 * necessary overridden methods.
 *
 * Note: This class is only used to support types that cannot be supported by
 * the phi Tensor system temporarily. You CANNOT use this class to handle types
 * such as DenseTensor, SelectedRows, etc.
 **/
class VariableCompatTensor
    : public phi::TensorBase,
      public phi::TypeInfoTraits<phi::TensorBase, VariableCompatTensor> {
 public:
  // Returns a const reference to the held object, which must be exactly of
  // the registered var type T. Enforces (fatally) that the tensor is
  // initialized and that the held type id matches T's.
  template <typename T>
  const T& Get() const {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(),
        paddle::framework::VarTypeTrait<T>::kId,
        paddle::platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            paddle::framework::ToTypeName(
                paddle::framework::VarTypeTrait<T>::kId),
            paddle::framework::ToTypeName(holder_->Type())));
    return *static_cast<const T*>(holder_->Ptr());
  }

  // True once a holder has been created via GetMutable<T>() and Clear() has
  // not been called since.
  bool IsInitialized() const { return holder_ != nullptr; }

  // Returns a mutable pointer to the held object, default-constructing a
  // PlaceholderImpl<T> on first use. If a holder already exists, enforces
  // that its type id matches T's (a variable never changes type silently).
  template <typename T>
  T* GetMutable() {
    if (!holder_) {
      holder_.reset(new PlaceholderImpl<T>());
    } else {
      PADDLE_ENFORCE_EQ(
          holder_->Type(),
          paddle::framework::VarTypeTrait<T>::kId,
          paddle::platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              paddle::framework::ToTypeName(
                  paddle::framework::VarTypeTrait<T>::kId),
              paddle::framework::ToTypeName(holder_->Type())));
    }
    return static_cast<T*>(holder_->Ptr());
  }

  // True iff a holder exists and its registered type id is exactly T's.
  template <typename T>
  bool IsType() const {
    return holder_ &&
           holder_->Type() == paddle::framework::VarTypeTrait<T>::kId;
  }

  // Destroys the held object (if any); IsInitialized() becomes false.
  void Clear() { holder_.reset(); }

  // Returns the VarTypeTrait type id of the held object; enforces that the
  // tensor is initialized.
  int Type() const {
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    return holder_->Type();
  }

  // necessary overridden methods

  static const char* name() { return "VariableCompatTensor"; }

  ~VariableCompatTensor() override = default;

  // The TensorBase metadata interface is intentionally unsupported: this
  // class holds non-tensor payloads (e.g. Vocab/Strings), which have no
  // numel/dims/dtype/layout/place. Each override throws Unavailable.
  int64_t numel() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `numel` method."));
  }

  const phi::DDim& dims() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dims` method."));
  }

  phi::DataType dtype() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dtype` method."));
  }

  phi::DataLayout layout() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `layout` method."));
  }

  const phi::Place& place() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `place` method."));
  }

  bool valid() const override { return IsInitialized(); }

  bool initialized() const override { return IsInitialized(); }

  // Allocation is unsupported for the same reason as the metadata methods.
  void* AllocateFrom(phi::Allocator* allocator UNUSED,
                     phi::DataType dtype UNUSED,
                     size_t requested_size UNUSED = 0,
                     bool fake_alloc UNUSED = false) override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `AllocateFrom` method."));
  }

 private:
  // Type-erased holder: records the registered type id and a raw pointer to
  // the concrete object owned by the derived PlaceholderImpl<T>.
  struct Placeholder {
    virtual ~Placeholder() PADDLE_MAY_THROW {}

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

    void* ptr_;
    int type_;
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PlaceholderImpl() {
      this->Init(&obj_, paddle::framework::VarTypeTrait<T>::kId);
    }

   private:
    T obj_;  // the actual payload, owned in-place by the holder
  };

  // pointers to a PlaceholderImpl object indeed.
  std::shared_ptr<Placeholder> holder_;
};

182
inline bool IsVariableCompatTensor(const paddle::Tensor& tensor) {
183 184 185
  return VariableCompatTensor::classof(tensor.impl().get());
}

J
Jiabin Yang 已提交
186 187
/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
188
 * Mode, the better choice is to use paddle::Tensor directly.
189 190 191 192
 * However, we have a punch of nested kernel code, and they use
 * paddle::framework::Variable in inner logic code. So, we have to provide
 * variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Phi Lib, and use
193
 * paddle::Tensor instead.
J
Jiabin Yang 已提交
194 195 196
 *
 * Note: Keep this class as clean as possible.
 * This class should only support method declared in
197
 * paddle::Tensor with access method of
J
Jiabin Yang 已提交
198 199
 * paddle::framework::Variable no more members are acceptable.
 * **/
200
class EagerVariable final {
J
Jiabin Yang 已提交
201
 public:
202 203
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
204
  EagerVariable() = default;
205

206
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
207

208
  explicit EagerVariable(const paddle::Tensor& tensor) : name_(tensor.name()) {
209 210
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
211
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
212
        src_tensor_ = tensor.impl();
213
      } else if (tensor.is_selected_rows()) {
214
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
215 216 217 218 219 220 221 222
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Vocab>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Vocab>(tensor);
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Strings>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Strings>(tensor);
J
Jiabin Yang 已提交
223 224
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
225
            "Unrecognized egr::EagerVariable type, only "
226
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
227
      }
228
    } else {
229
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
230 231
    }
  }
232

233 234 235 236 237 238 239 240 241 242 243 244 245
  ~EagerVariable() {
    if (src_tensor_) {
      auto* framework_tensor = var_.GetMutable<phi::DenseTensor>();
      auto tensor_dense = static_cast<phi::DenseTensor*>(src_tensor_.get());
      if (framework_tensor->memory_size() > 0 &&
          (!paddle::platform::is_same_place(framework_tensor->place(),
                                            tensor_dense->place()) ||
           framework_tensor->dtype() != tensor_dense->dtype())) {
        tensor_dense->ShareBufferWith(*framework_tensor);
      }
    }
  }

246 247
  /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
248
    // Construct allocation only once.
249
    if (var_.IsInitialized()) {
250
      if (var_.IsType<phi::DenseTensor>() || var_.IsType<phi::DenseTensor>()) {
251 252 253
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
J
Jiabin Yang 已提交
254
      } else {
255 256
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
257
            "from EagerVariable, only LoDTensor and "
258
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
259
      }
260 261
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
262
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
263 264
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
265 266
    }
  }
267 268 269
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
270 271 272

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

273
  const std::string& name() const { return name_; }
274

275
  void set_name(const std::string& name) { name_ = name; }
276

J
Jiabin Yang 已提交
277
 private:
J
Jiabin Yang 已提交
278
  template <typename VarType>
279
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
280
    const auto& framework_tensor = var_.Get<VarType>();
281
    VLOG(8) << "Sync Var to tensor for: " << name();
J
Jiabin Yang 已提交
282
    return std::make_shared<VarType>(framework_tensor);
283
  }
284

J
Jiabin Yang 已提交
285
  template <typename VarType>
286
  void ConstructVariableFromTensor(const paddle::Tensor& tensor) {
J
Jiabin Yang 已提交
287
    auto* framework_tensor = var_.GetMutable<VarType>();
288
    // Contruct phi::DenseTensor from egr::EagerVariable
J
Jiabin Yang 已提交
289
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
290
    PADDLE_ENFORCE_EQ(
291 292
        (tensor_dense.get() && tensor_dense),
        true,
293
        paddle::platform::errors::Fatal(
294
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor. "
295 296 297 298 299 300
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

301
  template <typename VarType>
302
  void ConstructVariableFromCompatTensor(const paddle::Tensor& tensor) {
303
    auto* framework_holder = var_.GetMutable<VarType>();
304
    // Contruct phi::DenseTensor from egr::EagerVariable
305 306 307 308 309 310 311 312 313 314 315
    auto* compat_tensor =
        static_cast<VariableCompatTensor*>(tensor.impl().get());
    PADDLE_ENFORCE_NOT_NULL(compat_tensor,
                            paddle::platform::errors::Fatal(
                                "Tensor %s holds empty impl, this should not "
                                "happend since we should "
                                "treat all kinds of tensor as what they are.",
                                tensor.name()));
    *framework_holder = compat_tensor->Get<VarType>();
  }

316 317
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
318
  paddle::framework::Variable var_;
319
  std::shared_ptr<phi::TensorBase> src_tensor_;
J
Jiabin Yang 已提交
320 321
};
}  // namespace egr