//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>

#include "paddle/fluid/framework/var_type_traits.h"

namespace paddle {
namespace framework {

class Variable {
 public:
  template <typename T>
  const T& Get() const {
S
sneaxiy 已提交
30 31 32
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
33 34 35 36 37 38 39
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
    PADDLE_ENFORCE_EQ(
        holder_->Type(), VarTypeTrait<T>::kId,
        platform::errors::InvalidArgument(
            "The Variable type must be %s, but the type it holds is %s.",
            ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
40 41 42
    return *static_cast<const T*>(holder_->Ptr());
  }

43 44
  bool IsInitialized() const { return holder_ != nullptr; }

Y
Yi Wang 已提交
45 46
  template <typename T>
  T* GetMutable() {
X
Xin Pan 已提交
47
    if (!holder_) {
S
sneaxiy 已提交
48
      holder_.reset(new PlaceholderImpl<T>());
X
Xin Pan 已提交
49
    } else {
50 51 52 53 54
      PADDLE_ENFORCE_EQ(
          holder_->Type(), VarTypeTrait<T>::kId,
          platform::errors::InvalidArgument(
              "The Variable type must be %s, but the type it holds is %s.",
              ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type())));
Y
Yi Wang 已提交
55
    }
Y
Yi Wang 已提交
56
    return static_cast<T*>(holder_->Ptr());
Y
Yi Wang 已提交
57 58
  }

Y
Yu Yang 已提交
59 60
  template <typename T>
  bool IsType() const {
S
sneaxiy 已提交
61
    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
Y
Yu Yang 已提交
62 63
  }

Y
Yu Yang 已提交
64 65
  void Clear() { holder_.reset(); }

S
sneaxiy 已提交
66
  int Type() const {
67 68
    PADDLE_ENFORCE_NOT_NULL(
        holder_, platform::errors::NotFound("Variable is not initialized."));
69 70 71
    return holder_->Type();
  }

Y
Yi Wang 已提交
72 73
 private:
  struct Placeholder {
Z
Zeng Jinle 已提交
74
    virtual ~Placeholder() PADDLE_MAY_THROW {}
S
sneaxiy 已提交
75 76 77 78 79 80

    inline int Type() const { return type_; }
    inline const void* Ptr() const { return ptr_; }
    inline void* Ptr() { return ptr_; }

   protected:
S
sneaxiy 已提交
81 82 83 84 85
    inline void Init(void* p, int type) {
      ptr_ = p;
      type_ = type;
    }

S
sneaxiy 已提交
86 87
    void* ptr_;
    int type_;
Y
Yi Wang 已提交
88 89 90 91 92 93
  };

  // Placeholder hides type T, so it doesn't appear as a template
  // parameter of Variable.
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
S
sneaxiy 已提交
94 95 96
    static_assert(
        IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
S
sneaxiy 已提交
97
    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
Y
Yi Wang 已提交
98

S
sneaxiy 已提交
99 100
   private:
    T obj_;
Y
Yi Wang 已提交
101 102
  };

S
sneaxiy 已提交
103 104
  // pointers to a PlaceholderImpl object indeed.
  std::unique_ptr<Placeholder> holder_;
Y
Yi Wang 已提交
105 106 107 108
};

}  // namespace framework
}  // namespace paddle