eager_tensor.h 11.2 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// framework deps
17
#include "paddle/fluid/framework/phi_utils.h"
J
Jiabin Yang 已提交
18 19
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
20
// Phi deps
21 22 23
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53

namespace egr {

/**
 * VariableCompatTensor class is used by Eager mode for now. It's painful to
 * do this in Eager Mode; the better choice is to design the special Tensor
 * directly in phi and use it in paddle::experimental::Tensor.
 * However, some special operators use special input variable types, such as
 * vector<string> and unordered_map<wstring, int>, which cannot be covered by
 * DenseTensor or SparseTensor. So we have to provide a compatible Tensor type,
 * similar to framework::Variable, to support these special input types. We
 * should remove this as soon as we finish the ResourceTensor in phi.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support the methods declared in framework::Variable
 * and the necessary overridden methods.
 *
 * Note: This class is only used to support types that cannot be supported by
 * the phi Tensor system temporarily. You CANNOT use this class to handle types
 * such as DenseTensor, SelectedRows, etc.
 **/
class VariableCompatTensor
    : public phi::TensorBase,
      public phi::TypeInfoTraits<phi::TensorBase, VariableCompatTensor> {
 public:
  /**
   * Returns a const reference to the held object of type T.
   * Throws NotFound if nothing is held, and InvalidArgument if the held
   * object's registered type id differs from VarTypeTrait<T>::kId.
   */
  template <typename T>
  const T& Get() const {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(),
        paddle::framework::VarTypeTrait<T>::kId,
        paddle::platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            paddle::framework::ToTypeName(
                paddle::framework::VarTypeTrait<T>::kId),
            paddle::framework::ToTypeName(holder_->Type())));
    return *static_cast<const T*>(holder_->Ptr());
  }

  /// True once some object has been created via GetMutable<T>().
  bool IsInitialized() const { return holder_ != nullptr; }

  /**
   * Returns a mutable pointer to the held object of type T, default
   * constructing one on first use. If an object is already held, its type
   * must match T, otherwise InvalidArgument is thrown.
   */
  template <typename T>
  T* GetMutable() {
    if (!holder_) {
      holder_ = std::make_shared<PlaceholderImpl<T>>();
    } else {
      PADDLE_ENFORCE_EQ(
          holder_->Type(),
          paddle::framework::VarTypeTrait<T>::kId,
          paddle::platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              paddle::framework::ToTypeName(
                  paddle::framework::VarTypeTrait<T>::kId),
              paddle::framework::ToTypeName(holder_->Type())));
    }
    return static_cast<T*>(holder_->Ptr());
  }

  /// True iff an object is held and its registered type id is T's.
  template <typename T>
  bool IsType() const {
    return holder_ &&
           holder_->Type() == paddle::framework::VarTypeTrait<T>::kId;
  }

  /// Drops this tensor's reference to the held object.
  void Clear() { holder_.reset(); }

  /// Registered type id of the held object; throws NotFound if empty.
  int Type() const {
    PADDLE_ENFORCE_NOT_NULL(
        holder_,
        paddle::platform::errors::NotFound("Variable is not initialized."));
    return holder_->Type();
  }

  // necessary overridden methods

  static const char* name() { return "VariableCompatTensor"; }

  ~VariableCompatTensor() override = default;

  // The dense-tensor style queries below are meaningless for the opaque
  // objects held here, so each of them throws Unavailable.

  int64_t numel() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `numel` method."));
  }

  const phi::DDim& dims() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dims` method."));
  }

  phi::DataType dtype() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dtype` method."));
  }

  phi::DataLayout layout() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `layout` method."));
  }

  const phi::Place& place() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `place` method."));
  }

  bool valid() const override { return IsInitialized(); }

  bool initialized() const override { return IsInitialized(); }

  void* AllocateFrom(phi::Allocator* allocator,
                     phi::DataType dtype,
                     size_t requested_size = 0) override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `AllocateFrom` method."));
  }

 private:
  // Type-erased base: stores an untyped pointer to the held object plus its
  // registered type id.
  struct Placeholder {
    virtual ~Placeholder() PADDLE_MAY_THROW {}

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

    // Both are overwritten by Init() in PlaceholderImpl's constructor.
    void* ptr_{nullptr};
    int type_{-1};
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of VariableCompatTensor.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PlaceholderImpl() {
      this->Init(&obj_, paddle::framework::VarTypeTrait<T>::kId);
    }

   private:
    T obj_;
  };

  // Points to a PlaceholderImpl<T> object; null until GetMutable<T>().
  std::shared_ptr<Placeholder> holder_;
};

/**
 * Returns true iff `tensor` has a non-null impl that is a
 * VariableCompatTensor. An undefined tensor (null impl) yields false instead
 * of handing a null pointer to classof, which presumably dereferences its
 * argument to read the type info.
 */
inline bool IsVariableCompatTensor(const paddle::experimental::Tensor& tensor) {
  const phi::TensorBase* impl = tensor.impl().get();
  return impl != nullptr && VariableCompatTensor::classof(impl);
}

/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
 * Mode; the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a bunch of nested kernel code that uses
 * paddle::framework::Variable in its inner logic, so we have to provide a
 * Variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Phi Lib, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support the methods declared in
 * paddle::experimental::Tensor, plus the access methods of
 * paddle::framework::Variable; no more members are acceptable.
 **/
class EagerVariable final {
J
Jiabin Yang 已提交
200
 public:
201 202
  /* Default constructor and name constructor should only be used for contruct
   * output and in fluid*/
203
  EagerVariable() = default;
204

205
  explicit EagerVariable(const std::string& name) : name_(name) {}
J
Jiabin Yang 已提交
206

207
  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
208 209 210
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
211
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
212
      } else if (tensor.is_selected_rows()) {
213
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
214 215 216 217 218 219 220 221
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Vocab>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Vocab>(tensor);
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Strings>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Strings>(tensor);
J
Jiabin Yang 已提交
222 223
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
224
            "Unrecognized egr::EagerVariable type, only "
225
            "DenseTensor and SelectedRows are supported for now."));
J
Jiabin Yang 已提交
226
      }
227
    } else {
228
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
J
Jiabin Yang 已提交
229 230
    }
  }
231

232 233
  /** Part 11: Construct paddle::framework::Variable with phi::Tensor **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
234
    // Construct allocation only once.
235
    if (var_.IsInitialized()) {
236 237
      if (var_.IsType<paddle::framework::LoDTensor>() ||
          var_.IsType<paddle::framework::Tensor>()) {
238 239 240
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
J
Jiabin Yang 已提交
241
      } else {
242 243
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
244
            "from EagerVariable, only LoDTensor and "
245
            "Tensor are supported for now"));
J
Jiabin Yang 已提交
246
      }
247 248
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
249
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
250 251
          "not initialized!",
          name()));
J
Jiabin Yang 已提交
252 253
    }
  }
254 255 256
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }
J
Jiabin Yang 已提交
257 258 259

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

260
  const std::string& name() const { return name_; }
261

262
  void set_name(const std::string& name) { name_ = name; }
263

J
Jiabin Yang 已提交
264
 private:
J
Jiabin Yang 已提交
265
  template <typename VarType>
266
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
J
Jiabin Yang 已提交
267
    const auto& framework_tensor = var_.Get<VarType>();
268
    VLOG(8) << "Sync Var to tensor for: " << name();
J
Jiabin Yang 已提交
269
    return std::make_shared<VarType>(framework_tensor);
270
  }
271

J
Jiabin Yang 已提交
272
  template <typename VarType>
273
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
J
Jiabin Yang 已提交
274
    auto* framework_tensor = var_.GetMutable<VarType>();
275
    // Contruct framework::Tensor from egr::EagerVariable
J
Jiabin Yang 已提交
276
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
277
    PADDLE_ENFORCE_EQ(
278 279
        (tensor_dense.get() && tensor_dense),
        true,
280
        paddle::platform::errors::Fatal(
281
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor. "
282 283 284 285 286 287
            "Or it holds empty impl, this should not happend since we should "
            "treat all kinds of tensor as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
  template <typename VarType>
  void ConstructVariableFromCompatTensor(
      const paddle::experimental::Tensor& tensor) {
    auto* framework_holder = var_.GetMutable<VarType>();
    // Contruct framework::Tensor from egr::EagerVariable
    auto* compat_tensor =
        static_cast<VariableCompatTensor*>(tensor.impl().get());
    PADDLE_ENFORCE_NOT_NULL(compat_tensor,
                            paddle::platform::errors::Fatal(
                                "Tensor %s holds empty impl, this should not "
                                "happend since we should "
                                "treat all kinds of tensor as what they are.",
                                tensor.name()));
    *framework_holder = compat_tensor->Get<VarType>();
  }

304 305
 private:
  std::string name_{""};
J
Jiabin Yang 已提交
306 307 308
  paddle::framework::Variable var_;
};
}  // namespace egr