From aac42b9a9849e1ffe319910675ef7940d6c295f3 Mon Sep 17 00:00:00 2001 From: superjomn Date: Tue, 26 Feb 2019 16:52:23 +0800 Subject: [PATCH] update --- paddle/fluid/lite/CMakeLists.txt | 5 + paddle/fluid/lite/memory.h | 5 +- paddle/fluid/lite/model_parser/CMakeLists.txt | 3 +- .../fluid/lite/model_parser/model_parser.cc | 137 +++++++++++++++++- paddle/fluid/lite/model_parser/model_parser.h | 9 +- .../lite/model_parser/model_parser_test.cc | 27 ++++ paddle/fluid/lite/op_lite.h | 5 +- paddle/fluid/lite/operators/fc_op.h | 3 +- paddle/fluid/lite/scope.cc | 44 ++++++ paddle/fluid/lite/scope.h | 4 +- paddle/fluid/lite/scope_test.cc | 37 +++++ paddle/fluid/lite/tensor.h | 15 +- paddle/fluid/lite/utils/all.h | 2 +- paddle/fluid/lite/utils/any.h | 129 ----------------- paddle/fluid/lite/utils/varient.h | 123 ++++++++++++++++ paddle/fluid/lite/variable.h | 15 +- 16 files changed, 413 insertions(+), 150 deletions(-) create mode 100644 paddle/fluid/lite/model_parser/model_parser_test.cc create mode 100644 paddle/fluid/lite/scope_test.cc delete mode 100644 paddle/fluid/lite/utils/any.h create mode 100644 paddle/fluid/lite/utils/varient.h diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index 19c3834e9d..15f2f30ba4 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -4,9 +4,14 @@ cc_library(memory_lite SRCS memory.cc) cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) cc_library(variable_lite SRCS variable.cc) cc_library(op_registry_lite SRCS op_registry.cc) +cc_library(scope_lite SRCS scope.cc) add_subdirectory(x86) add_subdirectory(cuda) add_subdirectory(operators) add_subdirectory(kernels) add_subdirectory(model_parser) + + +# tests +cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) diff --git a/paddle/fluid/lite/memory.h b/paddle/fluid/lite/memory.h index 65801c7bdd..36f247c209 100644 --- a/paddle/fluid/lite/memory.h +++ b/paddle/fluid/lite/memory.h @@ -19,7 +19,7 @@ namespace paddle { namespace lite { -void* TargetMalloc(TargetType target, size_t size) { +static void* TargetMalloc(TargetType target, size_t size) { void* data{nullptr}; switch (static_cast(target)) { case static_cast(TargetType::kX86): @@ -40,7 +40,7 @@ void* TargetMalloc(TargetType target, size_t size) { return data; } -void TargetFree(TargetType target, void* data) { +static void TargetFree(TargetType target, void* data) { switch (static_cast(target)) { case static_cast(TargetType::kX86): TargetWrapper::Free(data); @@ -59,6 +59,7 @@ void TargetFree(TargetType target, void* data) { // Memory buffer manager. class Buffer { public: + Buffer() = default; Buffer(TargetType target, size_t size) : space_(size), target_(target) {} void* data() const { return data_; } diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index 7d9a18c40f..78de6f9367 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -1,2 +1,3 @@ -cc_library(model_parser_lite SRCS model_parser.cc) +cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite) cc_library(runtime_lite SRCS runtime.cc) +cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite) diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc index 5a5fd1bcf9..59b518fb28 100644 --- a/paddle/fluid/lite/model_parser/model_parser.cc +++ b/paddle/fluid/lite/model_parser/model_parser.cc @@ -12,8 +12,137 @@ // See the License for the specific language governing permissions and // limitations under the License. -// -// Created by chunwei on 19-2-25. -// +#include "paddle/fluid/lite/model_parser/model_parser.h" +#include +#include "paddle/fluid/lite/scope.h" +#include "paddle/fluid/lite/tensor.h" +#include "paddle/fluid/lite/variable.h" + +namespace paddle { +namespace lite { + +int SizeOfType(framework::proto::VarType::Type type) { + using Type = framework::proto::VarType::Type; + switch (static_cast(type)) { +#define DO(desc, type) \ + case Type::VarType_Type_##desc: \ + return sizeof(type); + DO(BOOL, bool); + DO(FP16, float); + DO(FP32, float); + DO(INT8, int8_t); + DO(INT32, int); + DO(INT64, int64_t); +#undef DO + default: + LOG(FATAL) << "unknown data type"; + } +} + +void TensorFromStream(std::istream &is, lite::Tensor *tensor) { + using Type = framework::proto::VarType::Type; + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + CHECK_EQ(version, 0U) << "Only version 0 is supported"; + // read tensor desc + framework::proto::VarType::TensorDesc desc; + { + // int32_t size + // proto buffer + int32_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::unique_ptr buf(new char[size]); + is.read(reinterpret_cast(buf.get()), size); + CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc"; + } + + // read tensor + std::vector dims; + dims.reserve(static_cast(desc.dims().size())); + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); + tensor->Resize(dims); + void *buf; + size_t size = product(tensor->dims()) * SizeOfType(desc.data_type()); + // alllocate memory + switch (static_cast(desc.data_type())) { +#define DO(desc, type) \ + case Type::VarType_Type_##desc: \ + buf = tensor->mutable_data(); \ + break; + DO(BOOL, bool); + DO(FP32, float); + DO(INT8, int8_t); + DO(INT16, int16_t); + DO(INT32, int32_t); + DO(INT64, int64_t); +#undef DO + default: + LOG(FATAL) << "unknown type"; + } + + is.read(static_cast(buf), size); +} + +void LoadLoDTensor(std::istream &is, Variable *var) { + auto *tensor = var->GetMutable(); + uint32_t version; + is.read(reinterpret_cast(&version), sizeof(version)); + LOG(INFO) << "model version " << version; + + // Load LoD information + uint64_t lod_level; + is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); + auto &lod = *tensor->mutable_lod(); + lod.resize(lod_level); + for (uint64_t i = 0; i < lod_level; ++i) { + uint64_t size; + is.read(reinterpret_cast(&size), sizeof(size)); + std::vector tmp(size / sizeof(size_t)); + is.read(reinterpret_cast(tmp.data()), + static_cast(size)); + lod[i] = tmp; + } + + TensorFromStream(is, tensor); +} + +// TODO(Superjomn) support SelectedRows. + +void ReadBinaryFile(const std::string &filename, std::string *contents) { + std::ifstream fin(filename, std::ios::in | std::ios::binary); + CHECK(fin.is_open()) << "Cannot open file " << filename; + fin.seekg(0, std::ios::end); + auto size = fin.tellg(); + contents->clear(); + contents->resize(size); + fin.seekg(0, std::ios::beg); + fin.read(&(contents->at(0)), contents->size()); + fin.close(); +} + +std::unique_ptr LoadProgram( + const std::string &path) { + std::string desc_str; + ReadBinaryFile(path, &desc_str); + std::unique_ptr main_program( + new framework::proto::ProgramDesc); + main_program->ParseFromString(desc_str); + return main_program; +} + +void LoadParams(const std::string &path) {} + +void LoadModel(const std::string &model_dir, Scope *scope) { + const std::string prog_path = model_dir + "/__model__"; + auto prog = LoadProgram(prog_path); + + auto main_block = prog->blocks(0); + for (auto &var : main_block.vars()) { + std::string file_path = model_dir + "/" + var.name(); + std::ifstream file(file_path); + LoadLoDTensor(file, scope->Var(var.name())); + } +} -#include "model_parser.h" +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/model_parser.h b/paddle/fluid/lite/model_parser/model_parser.h index 75f13fca2a..f65edabb3f 100644 --- a/paddle/fluid/lite/model_parser/model_parser.h +++ b/paddle/fluid/lite/model_parser/model_parser.h @@ -15,15 +15,22 @@ // This file contains model format related operations, such as load a model, // parse an operator definitions and so on. +#include #include #include +#include "paddle/fluid/framework/framework.pb.h" namespace paddle { namespace lite { -void LoadProgram(const std::string& path); +// Read a __model__ file. +std::unique_ptr LoadProgram( + const std::string& path); + +// Read a single file containing all the parameters. void LoadParams(const std::string& path); +// Read a model and files of parameters. void LoadModel(const std::string& model_dir); } // namespace lite diff --git a/paddle/fluid/lite/model_parser/model_parser_test.cc b/paddle/fluid/lite/model_parser/model_parser_test.cc new file mode 100644 index 0000000000..1ce7eb83a5 --- /dev/null +++ b/paddle/fluid/lite/model_parser/model_parser_test.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/model_parser/model_parser.h" +#include + +namespace paddle { +namespace lite { + +TEST(ModelParser, LoadProgram) { + auto program = LoadProgram( + "/home/chunwei/project2/models/fc/fluid_checkpoint/__model__"); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/op_lite.h b/paddle/fluid/lite/op_lite.h index cf94ad26c8..2e58dec354 100644 --- a/paddle/fluid/lite/op_lite.h +++ b/paddle/fluid/lite/op_lite.h @@ -64,9 +64,10 @@ class OpLite : public Registry { framework::Scope *scope) = 0; virtual std::string DebugString() const = 0; - virtual void StaticPickKernel(const std::vector &valid_targets) = 0; + virtual void StaticPickKernel( + const std::vector &valid_targets) = 0; - void PickBestKernel(const std::vector &valid_places, + void PickBestKernel(const std::vector &valid_places, KernelStrategy kernel_strategy = KernelStrategy::kStatic); // Create all the kernels for the valid targets. diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index 588b5e7c11..1f07f434f5 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -49,7 +49,8 @@ class FcOpLite : public OpLite { std::string DebugString() const override { return "fc"; } - void StaticPickKernel(const std::vector& valid_targets) override {} + void StaticPickKernel(const std::vector& valid_targets) override { + } private: mutable FcParam param_; diff --git a/paddle/fluid/lite/scope.cc b/paddle/fluid/lite/scope.cc index 2c89c6168a..1c405464e5 100644 --- a/paddle/fluid/lite/scope.cc +++ b/paddle/fluid/lite/scope.cc @@ -13,3 +13,47 @@ // limitations under the License. #include "paddle/fluid/lite/scope.h" +#include "scope.h" + +namespace paddle { +namespace lite { + +Scope::~Scope() {} + +Scope &Scope::NewScope() const { + kids_.push_back(new Scope); + kids_.back()->parent_ = this; + return *kids_.back(); +} + +Variable *Scope::Var(const std::string &name) { + auto *var = FindVar(name); + if (var) return var; + + // create a new variable. + vars_.emplace(name, std::unique_ptr(new Variable)); + return vars_[name].get(); +} + +Variable *Scope::FindVar(const std::string &name) const { + Variable *var{nullptr}; + var = FindLocalVar(name); + const Scope *cur_scope = this; + while (!var && cur_scope->parent()) { + cur_scope = cur_scope->parent(); + var = cur_scope->FindLocalVar(name); + } + + return var; +} + +Variable *Scope::FindLocalVar(const std::string &name) const { + auto it = vars_.find(name); + if (it != vars_.end()) { + return it->second.get(); + } + return nullptr; +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/scope.h b/paddle/fluid/lite/scope.h index 3943c0818a..11b3f23484 100644 --- a/paddle/fluid/lite/scope.h +++ b/paddle/fluid/lite/scope.h @@ -19,6 +19,7 @@ #include #include #include +#include "paddle/fluid/lite/variable.h" namespace paddle { namespace lite { @@ -30,7 +31,7 @@ class Scope final { Scope& NewScope() const; - Variable* Var(std::string* name = nullptr); + Variable* Var(const std::string& name); Variable* FindVar(const std::string& name) const; @@ -42,6 +43,7 @@ class Scope final { // Scope in `kids_` are owned by this class. mutable std::list kids_; const Scope* parent_{nullptr}; + std::unordered_map> vars_; }; } // namespace lite diff --git a/paddle/fluid/lite/scope_test.cc b/paddle/fluid/lite/scope_test.cc new file mode 100644 index 0000000000..3e1c23172b --- /dev/null +++ b/paddle/fluid/lite/scope_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/scope.h" +#include + +namespace paddle { +namespace lite { + +TEST(Scope, Var) { + Scope scope; + auto* x = scope.Var("x"); + *x->GetMutable() = 100; + + ASSERT_EQ(x->Get(), 100); +} + +TEST(Scope, FindVar) { + Scope scope; + ASSERT_FALSE(scope.FindVar("x")); + scope.Var("x"); + ASSERT_TRUE(scope.FindVar("x")); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/tensor.h b/paddle/fluid/lite/tensor.h index 86b625bd23..d95548b191 100644 --- a/paddle/fluid/lite/tensor.h +++ b/paddle/fluid/lite/tensor.h @@ -37,25 +37,28 @@ class EventTree { std::vector children_; }; -using DDim = std::vector; -DDim SliceDims(const DDim& dims, int begin, int end) { +using DDim = std::vector; +static DDim SliceDims(const DDim& dims, int begin, int end) { return DDim(dims.begin() + begin, dims.begin() + end - 1); } -int product(const DDim& dims) { +static int product(const DDim& dims) { return std::accumulate(dims.begin(), dims.end(), 1, [](int a, int b) { return a * b; }); } -DDim flatten_to_2d(const DDim& dims, int col) { +static DDim flatten_to_2d(const DDim& dims, int col) { return DDim({product(SliceDims(dims, 0, col)), product(SliceDims(dims, col, dims.size()))}); } +using LoD = std::vector>; + // A light-weight tensor implementation. class Tensor { public: void SyncEventTree(); + Tensor() = default; template const T* data() const { @@ -66,6 +69,9 @@ class Tensor { const DDim& dims() const { return dims_; } + const LoD& lod() { return lod_; } + LoD* mutable_lod() { return &lod_; } + template T* mutable_data() { buffer_.ResetLazy(target_, product(dims_)); @@ -78,6 +84,7 @@ class Tensor { TargetType target_{TargetType::kHost}; DDim dims_; Buffer buffer_; + LoD lod_; }; } // namespace lite diff --git a/paddle/fluid/lite/utils/all.h b/paddle/fluid/lite/utils/all.h index 79f7e9e0eb..df07541a17 100644 --- a/paddle/fluid/lite/utils/all.h +++ b/paddle/fluid/lite/utils/all.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/lite/utils/any.h" +#include "paddle/fluid/lite/utils/varient.h" #include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/factory.h" #include "paddle/fluid/lite/utils/macros.h" diff --git a/paddle/fluid/lite/utils/any.h b/paddle/fluid/lite/utils/any.h deleted file mode 100644 index 20c4a6faad..0000000000 --- a/paddle/fluid/lite/utils/any.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include - -// This is an equivalent implementation of boost::any. We implement this to -// avoid including the whole boost library and keep the inference library small. -// These code references https://gist.github.com/shoooe/9202235 - -namespace paddle { -namespace lite { - -class any; -template -Type any_cast(any&); -template -Type any_cast(const any&); -template -Type* any_cast(any*); -template -const Type* any_cast(const any*); -struct bad_any_cast : public std::bad_cast {}; - -class any { - public: - template - friend Type any_cast(any&); - - template - friend Type any_cast(const any&); - - template - friend Type* any_cast(any*); - - template - friend const Type* any_cast(const any*); - - any() : ptr(nullptr) {} - explicit any(any&& x) : ptr(std::move(x.ptr)) {} - - explicit any(const any& x) { - if (x.ptr) ptr = x.ptr->clone(); - } - - template - explicit any(const Type& x) - : ptr(new concrete::type>(x)) {} - any& operator=(any&& rhs) { - ptr = std::move(rhs.ptr); - return (*this); - } - any& operator=(const any& rhs) { - ptr = std::move(any(rhs).ptr); - return (*this); - } - template - any& operator=(T&& x) { - ptr.reset(new concrete::type>( - typename std::decay::type(x))); - return (*this); - } - template - any& operator=(const T& x) { - ptr.reset(new concrete::type>( - typename std::decay::type(x))); - return (*this); - } - void clear() { ptr.reset(nullptr); } - bool empty() const { return ptr == nullptr; } - const std::type_info& type() const { - return (!empty()) ? ptr->type() : typeid(void); - } - - private: - struct placeholder { - virtual std::unique_ptr clone() const = 0; - virtual const std::type_info& type() const = 0; - virtual ~placeholder() {} - }; - - template - struct concrete : public placeholder { - explicit concrete(T&& x) : value(std::move(x)) {} - explicit concrete(const T& x) : value(x) {} - virtual std::unique_ptr clone() const override { - return std::unique_ptr(new concrete(value)); - } - virtual const std::type_info& type() const override { return typeid(T); } - T value; - }; - - std::unique_ptr ptr; -}; - -template -Type any_cast(any& val) { - if (val.ptr->type() != typeid(Type)) throw bad_any_cast(); - return static_cast*>(val.ptr.get())->value; -} -template -Type any_cast(const any& val) { - return any_cast(any(val)); -} -template -Type* any_cast(any* ptr) { - return dynamic_cast(ptr->ptr.get()); -} -template -const Type* any_cast(const any* ptr) { - return dynamic_cast(ptr->ptr.get()); -} - -} // namespace lite -} // namespace paddle diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h new file mode 100644 index 0000000000..3a3296976c --- /dev/null +++ b/paddle/fluid/lite/utils/varient.h @@ -0,0 +1,123 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include + +// This is an equivalent implementation of boost::any. We implement this to +// avoid including the whole boost library and keep the inference library small. +// These code references https://gist.github.com/shoooe/9202235 + +namespace paddle { +namespace lite { + +template +struct static_max; +template +struct static_max { + static const size_t value = arg; +}; +template +struct static_max { + static const size_t value = arg1 >= arg2 ? static_max::value + : static_max::value; +}; +template +struct variant_helper; +template +struct variant_helper { + inline static void destroy(size_t id, void* data) { + if (id == typeid(F).hash_code()) + reinterpret_cast(data)->~F(); + else + variant_helper::destroy(id, data); + } + inline static void move(size_t old_t, void* old_v, void* new_v) { + if (old_t == typeid(F).hash_code()) + new (new_v) F(std::move(*reinterpret_cast(old_v))); + else + variant_helper::move(old_t, old_v, new_v); + } + inline static void copy(size_t old_t, const void* old_v, void* new_v) { + if (old_t == typeid(F).hash_code()) + new (new_v) F(*reinterpret_cast(old_v)); + else + variant_helper::copy(old_t, old_v, new_v); + } +}; +template <> +struct variant_helper<> { + inline static void destroy(size_t id, void* data) {} + inline static void move(size_t old_t, void* old_v, void* new_v) {} + inline static void copy(size_t old_t, const void* old_v, void* new_v) {} +}; + +template +struct variant { + private: + static const size_t data_size = static_max::value; + static const size_t data_align = static_max::value; + using data_t = typename std::aligned_storage::type; + using helper_t = variant_helper; + static inline size_t invalid_type() { return typeid(void).hash_code(); } + size_t type_id; + data_t data; + + public: + variant() : type_id(invalid_type()) {} + variant(const variant& old) : type_id(old.type_id) { + helper_t::copy(old.type_id, &old.data, &data); + } + variant(variant&& old) : type_id(old.type_id) { + helper_t::move(old.type_id, &old.data, &data); + } + // Serves as both the move and the copy asignment operator. + variant& operator=(variant old) { + std::swap(type_id, old.type_id); + std::swap(data, old.data); + return *this; + } + template + void is() { + return (type_id == typeid(T).hash_code()); + } + + size_t type() { return type_id; } + + void valid() { return (type_id != invalid_type()); } + + template + void set(Args&&... args) { + // First we destroy the current contents + helper_t::destroy(type_id, &data); + new (&data) T(std::forward(args)...); + type_id = typeid(T).hash_code(); + } + template + T& get() { + // It is a dynamic_cast-like behaviour + if (type_id == typeid(T).hash_code()) + return *reinterpret_cast(&data); + else + throw std::bad_cast(); + } + ~variant() { helper_t::destroy(type_id, &data); } +}; + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/variable.h b/paddle/fluid/lite/variable.h index 83747b786f..bcff0aeef8 100644 --- a/paddle/fluid/lite/variable.h +++ b/paddle/fluid/lite/variable.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "paddle/fluid/lite/tensor.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { @@ -21,17 +22,23 @@ namespace lite { class Variable { public: template - T& Get() { - return blob_; + const T& Get() { + return blob_.get(); } template T* GetMutable() { - return any_cast(&blob_); + blob_.set(); + return &blob_.get(); + } + + template + bool IsType() { + return blob_.type() == typeid(T).hash_code(); } private: - any blob_; + variant blob_; }; } // namespace lite -- GitLab