variable.h 5.5 KB
Newer Older
1
//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2 3 4 5 6 7 8 9 10 11 12 13
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Y
Yi Wang 已提交
14 15
#pragma once

Y
Yi Wang 已提交
16
#include <memory>
17
#include <string>
Y
Yi Wang 已提交
18
#include <typeindex>
Y
Yi Wang 已提交
19 20
#include <typeinfo>

21
#include "paddle/fluid/framework/selected_rows.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/framework/var_type_traits.h"
Y
Yi Wang 已提交
23 24 25 26 27 28 29
namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
33 34 35 36 37 38 39
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(), VarTypeTrait<T>::kId,
        platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
40 41 42
    return *static_cast<const T*>(holder_->Ptr());
  }

43 44
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
45 46
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
47
    if (!holder_) {
S
sneaxiy 已提交
48
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
49
    } else {
50 51 52 53 54
      PADDLE_ENFORCE_EQ(
          holder_->Type(), VarTypeTrait<T>::kId,
          platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
55
    }
Y
Yi Wang 已提交
56
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
57 58
  }

Y
Yu Yang 已提交
59 60
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
61
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
62 63
  }

Y
Yu Yang 已提交
64 65
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
66
  int Type() const {
67 68
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
69 70 71
    return holder_->Type();
  }

72 73 74 75 76 77 78 79 80 81
  /**
   * The internal of two Variables share the same Placeholder whose type can be
   * Tensor, LoDTensor, SelectedRows, LoDTensorArray, etc.
   *
   * NOTE(liym27): In dynamic mode, sharing the same Placeholder also means
   * share the same TensorInplaceVersion, which is very important for inplace
   * operations.
   */
  void SharePlaceholderWith(const Variable& var);

82 83 84 85 86 87 88 89 90
 private:
  // This method hides type T, so it doesn't appear as a template parameter of
  // Variable.
  framework::TensorInplaceVersion* InplaceVersionCounter();

 public:
  uint32_t CurrentInplaceVersion();
  void BumpInplaceVersion();

Y
Yi Wang 已提交
91 92
 private:
  struct Placeholder {
Z
Zeng Jinle 已提交
93
    virtual ~Placeholder() PADDLE_MAY_THROW {}
S
sneaxiy 已提交
94 95 96 97 98 99

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
S
sneaxiy 已提交
100 101 102 103 104
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

S
sneaxiy 已提交
105 106
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
107 108 109 110 111 112
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
113 114 115
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
S
sneaxiy 已提交
116
    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
Y
Yi Wang 已提交
117

S
sneaxiy 已提交
118 119
   private:
    T obj_;
Y
Yi Wang 已提交
120 121
  };

S
sneaxiy 已提交
122
  // pointers to a PlaceholderImpl object indeed.
123
  std::shared_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
124 125
};

126 127 128 129 130 131 132 133
// Makes this Variable refer to the same placeholder as `var`.
//
// Enforces that `var` is initialized. On success this Variable's holder is
// replaced by a copy of var's shared_ptr, so both Variables refer to one
// placeholder object afterwards.
inline void Variable::SharePlaceholderWith(const Variable& var) {
  PADDLE_ENFORCE_EQ(var.IsInitialized(), true,
                    platform::errors::PreconditionNotMet(
                        "Variable holds no memory. "
                        "Call Variable::GetMutable() firstly."));
  // Copying the shared_ptr shares ownership; the held object is not copied.
  holder_ = var.holder_;
}

134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
// Locates the TensorInplaceVersion counter of the held object.
//
// Returns a pointer to the counter when the Variable holds a LoDTensor, a
// Tensor, or a SelectedRows (through its mutable value). Returns nullptr for
// every other held type and for an uninitialized Variable.
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
  // Type() enforces an initialized holder. Without this guard the fallback
  // log statement below would call Type() on an empty Variable and raise
  // whenever verbose logging is enabled.
  if (!IsInitialized()) {
    VLOG(4) << "Variable is not initialized, so it has no "
               "TensorInplaceVersion.";
    return nullptr;
  }

  // The LoDTensor check intentionally stays ahead of the Tensor check, as in
  // the original ordering.
  if (IsType<framework::LoDTensor>()) {
    return &GetMutable<framework::LoDTensor>()->InplaceVersionCounter();
  }
  if (IsType<framework::Tensor>()) {
    return &GetMutable<framework::Tensor>()->InplaceVersionCounter();
  }
  if (IsType<framework::SelectedRows>()) {
    return &GetMutable<framework::SelectedRows>()
                ->mutable_value()
                ->InplaceVersionCounter();
  }

  VLOG(4) << "Only supports Tensor, LoDTensor, SelectedRows to have "
             "TensorInplaceVersion, but received type "
          << platform::demangle(framework::ToTypeName(Type()));
  return nullptr;
}

// Reports the current inplace version of the held object, or 0 when the
// Variable has no TensorInplaceVersion counter.
inline uint32_t Variable::CurrentInplaceVersion() {
  auto* counter = InplaceVersionCounter();
  if (counter == nullptr) {
    return 0;
  }
  return counter->CurrentVersion();
}

inline void Variable::BumpInplaceVersion() {
  auto version_counter_ptr = InplaceVersionCounter();
  if (version_counter_ptr) {
    return version_counter_ptr->Bump();
  } else {
    VLOG(4) << "Only supports Tensor, LoDTensor, SelectedRows to have "
               "TensorInplaceVersion, but received type "
            << platform::demangle(framework::ToTypeName(Type()));
  }
}
Y
Yi Wang 已提交
174 175
}  // namespace framework
}  // namespace paddle