// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
// framework deps
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// Phi deps
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"

namespace egr {

/**
 * The VariableCompatTensor class is used by Eager mode for now. It is painful
 * to do this in Eager Mode; the better choice is to design the special Tensor
 * directly in phi and use it in paddle::experimental::Tensor.
 * However, some special operators take special input variable types, such as
 * vector<string> and unordered_map<wstring, int>, which cannot be covered by
 * DenseTensor or SparseTensor. So we have to provide a variable-like
 * compatible Tensor type to support these special input types. We should
 * remove this as soon as we finish the ResourceTensor in phi.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support the methods declared in framework::Variable
 * and the necessary overridden methods.
 *
 * Note: This class is only used to support types that cannot be supported by
 * the phi Tensor system temporarily. You CANNOT use this class to handle types
 * such as DenseTensor, SelectedRows, etc.
 **/
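// A minimal usage sketch, assuming paddle::framework::Vocab (an
// unordered_map<wstring, int>) is registered in var_type_traits.h; it
// exercises the Variable-style Get/GetMutable/IsType access described above:
//
//   egr::VariableCompatTensor compat;
//   auto* vocab = compat.GetMutable<paddle::framework::Vocab>();  // creates the holder
//   (*vocab)[L"hello"] = 1;                                       // fill the map
//   if (compat.IsType<paddle::framework::Vocab>()) {              // runtime type check
//     const auto& read_only = compat.Get<paddle::framework::Vocab>();
//     (void)read_only;
//   }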
class VariableCompatTensor
    : public phi::TensorBase,
      public phi::TypeInfoTraits<phi::TensorBase, VariableCompatTensor> {
 public:
  template <typename T>
  const T& Get() const {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PADDLE_ENFORCE_NOT_NULL(holder_, paddle::platform::errors::NotFound(
                                         "Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(), paddle::framework::VarTypeTrait<T>::kId,
        paddle::platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            paddle::framework::ToTypeName(
                paddle::framework::VarTypeTrait<T>::kId),
            paddle::framework::ToTypeName(holder_->Type())));
    return *static_cast<const T*>(holder_->Ptr());
  }

  bool IsInitialized() const { return holder_ != nullptr; }

  template <typename T>
  T* GetMutable() {
    if (!holder_) {
      holder_.reset(new PlaceholderImpl<T>());
    } else {
      PADDLE_ENFORCE_EQ(
          holder_->Type(), paddle::framework::VarTypeTrait<T>::kId,
          paddle::platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              paddle::framework::ToTypeName(
                  paddle::framework::VarTypeTrait<T>::kId),
              paddle::framework::ToTypeName(holder_->Type())));
    }
    return static_cast<T*>(holder_->Ptr());
  }

  template <typename T>
  bool IsType() const {
    return holder_ &&
           holder_->Type() == paddle::framework::VarTypeTrait<T>::kId;
  }

  void Clear() { holder_.reset(); }

  int Type() const {
    PADDLE_ENFORCE_NOT_NULL(holder_, paddle::platform::errors::NotFound(
                                         "Variable is not initialized."));
    return holder_->Type();
  }

  // necessary overridden methods

  static const char* name() { return "VariableCompatTensor"; }

  ~VariableCompatTensor() override = default;

  int64_t numel() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `numel` method."));
  }

  const phi::DDim& dims() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dims` method."));
  }

  phi::DataType dtype() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `dtype` method."));
  }

  phi::DataLayout layout() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `layout` method."));
  }

  const phi::Place& place() const override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `place` method."));
  }

  bool valid() const override { return IsInitialized(); }

  bool initialized() const override { return IsInitialized(); }

  void* AllocateFrom(phi::Allocator* allocator, phi::DataType dtype,
                     size_t requested_size = 0) override {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "VariableCompatTensor does not support `AllocateFrom` method."));
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() PADDLE_MAY_THROW {}

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

    void* ptr_;
    int type_;
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of VariableCompatTensor.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    static_assert(
        paddle::framework::IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PlaceholderImpl() {
      this->Init(&obj_, paddle::framework::VarTypeTrait<T>::kId);
    }

   private:
    T obj_;
  };

  // holder_ actually points to a PlaceholderImpl object.
  std::shared_ptr<Placeholder> holder_;
};

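// IsVariableCompatTensor() checks whether a paddle::experimental::Tensor's
// impl is a VariableCompatTensor; classof() comes from the
// phi::TypeInfoTraits base class.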
inline bool IsVariableCompatTensor(const paddle::experimental::Tensor& tensor) {
  return VariableCompatTensor::classof(tensor.impl().get());
}

/**
 * This class is used by Eager mode for now. It's painful to do this in Eager
 * Mode; the better choice is to use paddle::experimental::Tensor directly.
 * However, we have a bunch of nested kernel code that uses
 * paddle::framework::Variable in its inner logic. So, we have to provide a
 * variable in paddle::framework::ExecutionContext to support it. We should
 * remove this as soon as we finish our latest Phi Lib, and use
 * paddle::experimental::Tensor instead.
 *
 * Note: Keep this class as clean as possible.
 * This class should only support the methods declared in
 * paddle::experimental::Tensor plus the access methods of
 * paddle::framework::Variable; no more members are acceptable.
 **/
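// A minimal usage sketch, assuming `t` is a defined
// paddle::experimental::Tensor backed by a phi::DenseTensor:
//
//   egr::EagerVariable var(t);        // copies t's DenseTensor into the inner framework::Variable
//   auto base = var.GetTensorBase();  // fetches it back as a shared phi::TensorBase
//   const paddle::framework::Variable& fw_var = var.Var();
//   (void)base;
//   (void)fw_var;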
class EagerVariable final {
 public:
  /* Default constructor and name constructor should only be used to construct
   * outputs and in fluid */
  EagerVariable() = default;

  explicit EagerVariable(const std::string& name) : name_(name) {}

  explicit EagerVariable(const paddle::experimental::Tensor& tensor)
      : name_(tensor.name()) {
    if (tensor.defined()) {
      if (tensor.is_dense_tensor()) {
        ConstructVariableFromTensor<phi::DenseTensor>(tensor);
      } else if (tensor.is_selected_rows()) {
        ConstructVariableFromTensor<phi::SelectedRows>(tensor);
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Vocab>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Vocab>(tensor);
      } else if (IsVariableCompatTensor(tensor) &&
                 static_cast<const VariableCompatTensor*>(tensor.impl().get())
                     ->IsType<paddle::framework::Strings>()) {
        ConstructVariableFromCompatTensor<paddle::framework::Strings>(tensor);
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unrecognized egr::EagerVariable type, only "
            "DenseTensor and SelectedRows are supported for now."));
      }
    } else {
      VLOG(6) << "Build Empty EagerVariable with name " << name_;
    }
  }

  /** Part 11: Construct phi::TensorBase from paddle::framework::Variable **/
  std::shared_ptr<phi::TensorBase> GetTensorBase() {
    // Construct allocation only once.
    if (var_.IsInitialized()) {
      if (var_.IsType<paddle::framework::LoDTensor>() ||
          var_.IsType<paddle::framework::Tensor>()) {
        return SetImplWithLegacyTensor<phi::DenseTensor>();
      } else if (var_.IsType<phi::SelectedRows>()) {
        return SetImplWithLegacyTensor<phi::SelectedRows>();
      } else {
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Unable to fetch underlying tensor "
            "from EagerVariable, only LoDTensor and "
            "Tensor are supported for now"));
      }
    } else {
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Can not Sync EagerVariable %s whose paddle::framework::Variable is "
          "not initialized!",
          name()));
    }
  }
  const paddle::framework::Variable& Var() const { return var_; }

  paddle::framework::Variable* MutableVar() { return &var_; }

  void ResetVar(const paddle::framework::Variable& src) { var_ = src; }

  const std::string& name() const { return name_; }

  void set_name(const std::string& name) { name_ = name; }

 private:
  template <typename VarType>
  std::shared_ptr<phi::TensorBase> SetImplWithLegacyTensor() {
    const auto& framework_tensor = var_.Get<VarType>();
    VLOG(8) << "Sync Var to tensor for: " << name();
    return std::make_shared<VarType>(framework_tensor);
  }

  template <typename VarType>
  void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
    auto* framework_tensor = var_.GetMutable<VarType>();
    // Construct framework::Tensor from the paddle::experimental::Tensor impl
    auto tensor_dense = std::dynamic_pointer_cast<VarType>(tensor.impl());
    PADDLE_ENFORCE_EQ(
        (tensor_dense.get() && tensor_dense), true,
        paddle::platform::errors::Fatal(
            "Tensor %s does not hold phi::SelectedRows or phi::DenseTensor, "
            "or it holds an empty impl. This should not happen since we "
            "should treat all kinds of tensors as what they are.",
            tensor.name()));
    *framework_tensor = *tensor_dense;
  }

  template <typename VarType>
  void ConstructVariableFromCompatTensor(
      const paddle::experimental::Tensor& tensor) {
    auto* framework_holder = var_.GetMutable<VarType>();
    // Construct the framework holder from the VariableCompatTensor impl
    auto* compat_tensor =
        static_cast<VariableCompatTensor*>(tensor.impl().get());
    PADDLE_ENFORCE_NOT_NULL(compat_tensor,
                            paddle::platform::errors::Fatal(
                                "Tensor %s holds empty impl, this should not "
                                "happend since we should "
                                "treat all kinds of tensor as what they are.",
                                tensor.name()));
    *framework_holder = compat_tensor->Get<VarType>();
  }

 private:
  std::string name_{""};
  paddle::framework::Variable var_;
};
}  // namespace egr