// eager_tensor.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>

// framework deps
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// Phi deps
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"

namespace egr {

/**
 * VariableCompatTensor class is used by Eager mode for now. It's painful to
 * do this in Eager Mode; the better choice is to design the special Tensor
 * directly in phi and use it in paddle::experimental::Tensor.
 * However, we have some special operators that use special input variable
 * types, such as vector<string> and unordered_map<wstring, int>, which cannot
 * be covered by DenseTensor or SparseTensor. So we have to provide a
 * compatible, Variable-like Tensor type to support these special input types.
 * We should remove this as soon as we finish the ResourceTensor in phi.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support methods declared in framework::Variable and
 * the necessary overridden methods.
 *
 * Note: This class is only used to support types that cannot be supported by
 * the phi Tensor system temporarily. You CANNOT use this class to handle types
 * such as DenseTensor, SelectedRows, etc.
 *
 * The held object lives behind a std::shared_ptr, so copying a
 * VariableCompatTensor makes both copies refer to the same held object.
 **/
class VariableCompatTensor
    : public phi::TensorBase,
      public phi::TypeInfoTraits<phi::TensorBase, VariableCompatTensor> {
 public:
  /**
   * Returns a const reference to the held object.
   *
   * @throws NotFound        if nothing has been stored yet.
   * @throws InvalidArgument if the held object is not of type T.
   */
  template <typename T>
  const T& Get() const {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    EnforceHoldsType<T>();
    return *static_cast<const T*>(holder_->Ptr());
  }

  bool IsInitialized() const { return holder_ != nullptr; }

  /**
   * Returns a mutable pointer to the held T. If nothing is held yet, a
   * default-constructed T is created first.
   *
   * @throws InvalidArgument if an object of a different type is already held.
   */
  template <typename T>
  T* GetMutable() {
    if (holder_) {
      EnforceHoldsType<T>();
    } else {
      holder_ = std::make_shared<PlaceholderImpl<T>>();
    }
    return static_cast<T*>(holder_->Ptr());
  }

  /// True iff an object is held and its registered type id matches T's.
  template <typename T>
  bool IsType() const {
    return holder_ &&
           holder_->Type() == paddle::framework::VarTypeTrait<T>::kId;
  }

  /// Releases this tensor's reference to the held object.
  void Clear() { holder_.reset(); }

  /**
   * Returns the registered type id of the held object.
   * @throws NotFound if nothing has been stored yet.
   */
  int Type() const {
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    return holder_->Type();
  }

  // necessary overridden methods

  static const char* name() { return "VariableCompatTensor"; }

  ~VariableCompatTensor() override = default;

  int64_t numel() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `numel` method."));
  }

  const phi::DDim& dims() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dims` method."));
  }

  phi::DataType dtype() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dtype` method."));
  }

  phi::DataLayout layout() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `layout` method."));
  }

  const phi::Place& place() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `place` method."));
  }

  bool valid() const override { return IsInitialized(); }

  bool initialized() const override { return IsInitialized(); }

  void* AllocateFrom(phi::Allocator* allocator,
                     phi::DataType dtype,
                     size_t requested_size = 0,
                     bool fake_alloc = false) override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `AllocateFrom` method."));
  }

 private:
  // Shared type check for Get/GetMutable. Precondition: holder_ is non-null.
  template <typename T>
  void EnforceHoldsType() const {
    PADDLE_ENFORCE_EQ(
        holder_->Type(),
        paddle::framework::VarTypeTrait<T>::kId,
        paddle::platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            paddle::framework::ToTypeName(
                paddle::framework::VarTypeTrait<T>::kId),
            paddle::framework::ToTypeName(holder_->Type())));
  }

  struct Placeholder {
    virtual ~Placeholder() PADDLE_MAY_THROW {}

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

    // Default-initialized so the members are never indeterminate; Init()
    // overwrites both in every PlaceholderImpl constructor.
    void* ptr_{nullptr};
    int type_{-1};
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of VariableCompatTensor.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PlaceholderImpl() {
      this->Init(&obj_, paddle::framework::VarTypeTrait<T>::kId);
    }

   private:
    T obj_;
  };

  // Points to a PlaceholderImpl<T> object in practice.
  std::shared_ptr<Placeholder> holder_;
};

/// Returns whether `tensor`'s impl is a VariableCompatTensor, as reported by
/// VariableCompatTensor::classof.
/// NOTE(review): presumably requires a non-null impl (callers in this file
/// only call it after `tensor.defined()`) — confirm classof's null handling.
inline bool IsVariableCompatTensor(const paddle::experimental::Tensor& tensor) {
  const auto& impl = tensor.impl();
  const bool is_compat = VariableCompatTensor::classof(impl.get());
  return is_compat;
}

J
Jiabin Yang 已提交
186 187
/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
188 189 190 191 192 193
 * Mode, the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a punch of nested kernel code, and they use
 * paddle::framework::Variable in inner logic code. So, we have to provide
 * variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Phi Lib, and use
 * paddle::experimental::Tensor instead.
J
Jiabin Yang 已提交
194 195 196 197 198 199
 *
 * Note: Keep this class as clean as possible.
 * This class should only support method declared in
 * paddle::experimental::Tensor with access method of
 * paddle::framework::Variable no more members are acceptable.
 * **/
200
class EagerVariable final {
J
Jiabin Yang 已提交
201
 public:
202 203
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
204
  EagerVariable() = default;
205

206
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
207

208
  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
209 210 211
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
212
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
213
        src_tensor_ = tensor.impl();
214
      } else if (tensor.is_selected_rows()) {
215
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
216 217 218 219 220 221 222 223
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Vocab>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Vocab>(tensor);
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Strings>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Strings>(tensor);
J
Jiabin Yang 已提交
224 225
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
226
            "Unrecognized egr::EagerVariable type, only "
227
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
228
      }
229
    } else {
230
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
231 232
    }
  }
233

234 235 236 237 238 239 240 241 242 243 244 245 246
  ~EagerVariable() {
    if (src_tensor_) {
      auto* framework_tensor = var_.GetMutable<phi::DenseTensor>();
      auto tensor_dense = static_cast<phi::DenseTensor*>(src_tensor_.get());
      if (framework_tensor->memory_size() > 0 &&
          (!paddle::platform::is_same_place(framework_tensor->place(),
                                            tensor_dense->place()) ||
           framework_tensor->dtype() != tensor_dense->dtype())) {
        tensor_dense->ShareBufferWith(*framework_tensor);
      }
    }
  }

247 248
  /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
249
    // Construct allocation only once.
250
    if (var_.IsInitialized()) {
251
      if (var_.IsType<phi::DenseTensor>() || var_.IsType<phi::DenseTensor>()) {
252 253 254
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
J
Jiabin Yang 已提交
255
      } else {
256 257
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
258
            "from EagerVariable, only LoDTensor and "
259
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
260
      }
261 262
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
263
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
264 265
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
266 267
    }
  }
268 269 270
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
271 272 273

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

274
  const std::string& name() const { return name_; }
275

276
  void set_name(const std::string& name) { name_ = name; }
277

J
Jiabin Yang 已提交
278
 private:
J
Jiabin Yang 已提交
279
  template <typename VarType>
280
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
281
    const auto& framework_tensor = var_.Get<VarType>();
282
    VLOG(8) << "Sync Var to tensor for: " << name();
J
Jiabin Yang 已提交
283
    return std::make_shared<VarType>(framework_tensor);
284
  }
285

J
Jiabin Yang 已提交
286
  template <typename VarType>
287
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
J
Jiabin Yang 已提交
288
    auto* framework_tensor = var_.GetMutable<VarType>();
289
    // Contruct phi::DenseTensor from egr::EagerVariable
J
Jiabin Yang 已提交
290
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
291
    PADDLE_ENFORCE_EQ(
292 293
        (tensor_dense.get() && tensor_dense),
        true,
294
        paddle::platform::errors::Fatal(
295
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor. "
296 297 298 299 300 301
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

302 303 304 305
  template <typename VarType>
  void ConstructVariableFromCompatTensor(
      const paddle::experimental::Tensor& tensor) {
    auto* framework_holder = var_.GetMutable<VarType>();
306
    // Contruct phi::DenseTensor from egr::EagerVariable
307 308 309 310 311 312 313 314 315 316 317
    auto* compat_tensor =
        static_cast<VariableCompatTensor*>(tensor.impl().get());
    PADDLE_ENFORCE_NOT_NULL(compat_tensor,
                            paddle::platform::errors::Fatal(
                                "Tensor %s holds empty impl, this should not "
                                "happend since we should "
                                "treat all kinds of tensor as what they are.",
                                tensor.name()));
    *framework_holder = compat_tensor->Get<VarType>();
  }

318 319
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
320
  paddle::framework::Variable var_;
321
  std::shared_ptr<phi::TensorBase> src_tensor_;
J
Jiabin Yang 已提交
322 323
};
}  // namespace egr