//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

Y
Yi Wang 已提交
16
#include <memory>
17
#include <string>
Y
Yi Wang 已提交
18
#include <typeindex>
Y
Yi Wang 已提交
19 20
#include <typeinfo>

S
sneaxiy 已提交
21
#include "paddle/fluid/framework/var_type_traits.h"
Y
Yi Wang 已提交
22

Y
Yi Wang 已提交
23 24 25 26 27 28 29
namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
Y
Yu Yang 已提交
33
    PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
S
sneaxiy 已提交
34
    PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
Y
Yu Yang 已提交
35
                   "Variable must be type %s, the holding type is %s",
S
sneaxiy 已提交
36 37
                   ToTypeName(VarTypeTrait<T>::kId),
                   ToTypeName(holder_->Type()));
Y
Yi Wang 已提交
38 39 40
    return *static_cast<const T*>(holder_->Ptr());
  }

41 42
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
43 44
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
45
    if (!holder_) {
S
sneaxiy 已提交
46
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
47
    } else {
S
sneaxiy 已提交
48
      PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
X
Xin Pan 已提交
49
                     "Variable must be type %s, the holding type is %s",
S
sneaxiy 已提交
50 51
                     ToTypeName(VarTypeTrait<T>::kId),
                     ToTypeName(holder_->Type()));
Y
Yi Wang 已提交
52
    }
Y
Yi Wang 已提交
53
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
54 55
  }

Y
Yu Yang 已提交
56 57
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
58
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
59 60
  }

Y
Yu Yang 已提交
61 62
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
63
  int Type() const {
64 65 66 67
    PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
    return holder_->Type();
  }

Y
Yi Wang 已提交
68 69
 private:
  struct Placeholder {
S
sneaxiy 已提交
70 71 72 73 74 75 76 77 78 79
    explicit Placeholder(int type) : type_(type) {}
    virtual ~Placeholder() = default;

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
80 81 82 83 84 85
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
86 87 88 89 90 91
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
    PlaceholderImpl() : Placeholder(VarTypeTrait<T>::kId) {
      this->ptr_ = &obj_;
    }
Y
Yi Wang 已提交
92

S
sneaxiy 已提交
93 94
   private:
    T obj_;
Y
Yi Wang 已提交
95 96
  };

S
sneaxiy 已提交
97 98
  // pointers to a PlaceholderImpl object indeed.
  std::unique_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
99 100 101 102
};

}  // namespace framework
}  // namespace paddle