//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

Y
Yi Wang 已提交
16
#include <memory>
17
#include <string>
Y
Yi Wang 已提交
18
#include <typeindex>
Y
Yi Wang 已提交
19 20
#include <typeinfo>

S
sneaxiy 已提交
21
#include "paddle/fluid/framework/var_type_traits.h"
Y
Yi Wang 已提交
22

Y
Yi Wang 已提交
23 24 25 26 27 28 29
namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
33
    PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
S
sneaxiy 已提交
34
    PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
35
                   "The Variable type must be %s, but the type it holds is %s.",
S
sneaxiy 已提交
36 37
                   ToTypeName(VarTypeTrait<T>::kId),
                   ToTypeName(holder_->Type()));
Y
Yi Wang 已提交
38 39 40
    return *static_cast<const T*>(holder_->Ptr());
  }

41 42
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
43 44
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
45
    if (!holder_) {
S
sneaxiy 已提交
46
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
47
    } else {
48 49 50 51
      PADDLE_ENFORCE(
          holder_->Type() == VarTypeTrait<T>::kId,
          "The Variable type must be %s, but the type it holds is %s.",
          ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type()));
Y
Yi Wang 已提交
52
    }
Y
Yi Wang 已提交
53
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
54 55
  }

Y
Yu Yang 已提交
56 57
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
58
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
59 60
  }

Y
Yu Yang 已提交
61 62
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
63
  int Type() const {
64
    PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
65 66 67
    return holder_->Type();
  }

Y
Yi Wang 已提交
68 69
 private:
  struct Placeholder {
S
sneaxiy 已提交
70 71 72 73 74 75 76
    virtual ~Placeholder() = default;

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
S
sneaxiy 已提交
77 78 79 80 81
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

S
sneaxiy 已提交
82 83
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
84 85 86 87 88 89
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
90 91 92
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
S
sneaxiy 已提交
93
    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
Y
Yi Wang 已提交
94

S
sneaxiy 已提交
95 96
   private:
    T obj_;
Y
Yi Wang 已提交
97 98
  };

S
sneaxiy 已提交
99 100
  // pointers to a PlaceholderImpl object indeed.
  std::unique_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
101 102 103 104
};

}  // namespace framework
}  // namespace paddle