variable.h 4.8 KB
Newer Older
1
//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2 3 4 5 6 7 8 9 10 11 12 13
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Y
Yi Wang 已提交
14 15
#pragma once

Y
Yi Wang 已提交
16
#include <memory>
17
#include <string>
Y
Yi Wang 已提交
18
#include <typeindex>
Y
Yi Wang 已提交
19 20
#include <typeinfo>

21
#include "paddle/fluid/framework/selected_rows.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/framework/var_type_traits.h"
Y
Yi Wang 已提交
23 24 25 26 27 28 29
namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
33 34 35 36 37 38 39
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(), VarTypeTrait<T>::kId,
        platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
40 41 42
    return *static_cast<const T*>(holder_->Ptr());
  }

43 44
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
45 46
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
47
    if (!holder_) {
S
sneaxiy 已提交
48
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
49
    } else {
50 51 52 53 54
      PADDLE_ENFORCE_EQ(
          holder_->Type(), VarTypeTrait<T>::kId,
          platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
55
    }
Y
Yi Wang 已提交
56
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
57 58
  }

Y
Yu Yang 已提交
59 60
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
61
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
62 63
  }

Y
Yu Yang 已提交
64 65
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
66
  int Type() const {
67 68
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
69 70 71
    return holder_->Type();
  }

72 73 74 75 76 77 78 79 80
 private:
  // This method hides type T, so it doesn't appear as a template parameter of
  // Variable.
  framework::TensorInplaceVersion* InplaceVersionCounter();

 public:
  uint32_t CurrentInplaceVersion();
  void BumpInplaceVersion();

Y
Yi Wang 已提交
81 82
 private:
  struct Placeholder {
Z
Zeng Jinle 已提交
83
    virtual ~Placeholder() PADDLE_MAY_THROW {}
S
sneaxiy 已提交
84 85 86 87 88 89

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
S
sneaxiy 已提交
90 91 92 93 94
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

S
sneaxiy 已提交
95 96
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
97 98 99 100 101 102
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
103 104 105
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
S
sneaxiy 已提交
106
    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
Y
Yi Wang 已提交
107

S
sneaxiy 已提交
108 109
   private:
    T obj_;
Y
Yi Wang 已提交
110 111
  };

S
sneaxiy 已提交
112
  // pointers to a PlaceholderImpl object indeed.
113
  std::shared_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
114 115
};

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
// Returns a pointer to the inplace version counter of the held object, or
// nullptr when there is none.
//
// A counter is available only when the variable holds a LoDTensor, a Tensor,
// or a SelectedRows (whose counter is taken from its mutable_value()).
// An uninitialized variable yields nullptr as well.
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
  // Handle the empty variable explicitly: the fallback log below calls
  // Type(), which enforces that the variable is initialized. Without this
  // guard the outcome for an empty variable would depend on whether VLOG(4)
  // is enabled.
  if (!IsInitialized()) {
    VLOG(4) << "Variable is not initialized, so it has no "
               "TensorInplaceVersion.";
    return nullptr;
  }

  if (IsType<framework::LoDTensor>()) {
    return &GetMutable<framework::LoDTensor>()->InplaceVersionCounter();
  }
  if (IsType<framework::Tensor>()) {
    return &GetMutable<framework::Tensor>()->InplaceVersionCounter();
  }
  if (IsType<framework::SelectedRows>()) {
    return &GetMutable<framework::SelectedRows>()
                ->mutable_value()
                ->InplaceVersionCounter();
  }

  VLOG(4) << "Only supports Tensor, LoDTensor, SelectedRows to have "
             "TensorInplaceVersion, but received type "
          << platform::demangle(framework::ToTypeName(Type()));
  return nullptr;
}

inline uint32_t Variable::CurrentInplaceVersion() {
  auto version_counter_ptr = InplaceVersionCounter();
  if (version_counter_ptr) {
    return version_counter_ptr->CurrentVersion();
  } else {
    return 0;
  }
}

inline void Variable::BumpInplaceVersion() {
  auto version_counter_ptr = InplaceVersionCounter();
  if (version_counter_ptr) {
    return version_counter_ptr->Bump();
  } else {
    VLOG(4) << "Only supports Tensor, LoDTensor, SelectedRows to have "
               "TensorInplaceVersion, but received type "
            << platform::demangle(framework::ToTypeName(Type()));
  }
}
Y
Yi Wang 已提交
156 157
}  // namespace framework
}  // namespace paddle