//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/var_type_traits.h"
namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
33 34 35 36 37 38 39
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(), VarTypeTrait<T>::kId,
        platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
40 41 42
    return *static_cast<const T*>(holder_->Ptr());
  }

43 44
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
45 46
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
47
    if (!holder_) {
S
sneaxiy 已提交
48
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
49
    } else {
50 51 52 53 54
      PADDLE_ENFORCE_EQ(
          holder_->Type(), VarTypeTrait<T>::kId,
          platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
55
    }
Y
Yi Wang 已提交
56
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
57 58
  }

Y
Yu Yang 已提交
59 60
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
61
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
62 63
  }

Y
Yu Yang 已提交
64 65
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
66
  int Type() const {
67 68
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
69 70 71
    return holder_->Type();
  }

72 73 74
 private:
  // This method hides type T, so it doesn't appear as a template parameter of
  // Variable.
75
  pten::DenseTensor::InplaceVersion* InplaceVersionCounter();
76 77

 public:
78
  void SetInplaceVersionToZero();
79 80 81
  uint32_t CurrentInplaceVersion();
  void BumpInplaceVersion();

Y
Yi Wang 已提交
82 83
 private:
  struct Placeholder {
Z
Zeng Jinle 已提交
84
    virtual ~Placeholder() PADDLE_MAY_THROW {}
S
sneaxiy 已提交
85 86 87 88 89 90

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
S
sneaxiy 已提交
91 92 93 94 95
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

S
sneaxiy 已提交
96 97
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
98 99 100 101 102 103
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
104 105 106
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
S
sneaxiy 已提交
107
    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
Y
Yi Wang 已提交
108

S
sneaxiy 已提交
109 110
   private:
    T obj_;
Y
Yi Wang 已提交
111 112
  };

S
sneaxiy 已提交
113
  // pointers to a PlaceholderImpl object indeed.
114
  std::shared_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
115 116
};

117 118
inline pten::DenseTensor::InplaceVersion* Variable::InplaceVersionCounter() {
  pten::DenseTensor::InplaceVersion* version_counter_ptr(nullptr);
119 120 121 122 123 124 125
  if (IsType<framework::LoDTensor>()) {
    version_counter_ptr =
        &GetMutable<framework::LoDTensor>()->InplaceVersionCounter();
  } else if (IsType<framework::Tensor>()) {
    version_counter_ptr =
        &GetMutable<framework::Tensor>()->InplaceVersionCounter();

126 127
  } else if (IsType<pten::SelectedRows>()) {
    version_counter_ptr = &GetMutable<pten::SelectedRows>()
128 129 130 131 132 133 134 135 136 137
                               ->mutable_value()
                               ->InplaceVersionCounter();
  } else {
    VLOG(4) << "Only supports Tensor, LoDTensor, SelectedRows to have "
               "TensorInplaceVersion, but received type "
            << platform::demangle(framework::ToTypeName(Type()));
  }
  return version_counter_ptr;
}

138 139 140 141 142 143
inline void Variable::SetInplaceVersionToZero() {
  auto inplace_version_counter = this->InplaceVersionCounter();
  if (inplace_version_counter)
    inplace_version_counter->SetInplaceVersionToZero();
}

144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
inline uint32_t Variable::CurrentInplaceVersion() {
  auto version_counter_ptr = InplaceVersionCounter();
  if (version_counter_ptr) {
    return version_counter_ptr->CurrentVersion();
  } else {
    return 0;
  }
}

// Advances the inplace version of the held tensor. When the held object has no
// counter this is a no-op. InplaceVersionCounter() already logs at VLOG(4)
// whenever it returns nullptr, so the case is not logged a second time here.
inline void Variable::BumpInplaceVersion() {
  auto* version_counter = InplaceVersionCounter();
  if (version_counter != nullptr) {
    version_counter->Bump();
  }
}
}  // namespace framework
}  // namespace paddle