Commit aac42b9a authored by: S superjomn

update

Parent e5b563e6
@@ -4,9 +4,14 @@
 cc_library(memory_lite SRCS memory.cc)
 cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
 cc_library(variable_lite SRCS variable.cc)
 cc_library(op_registry_lite SRCS op_registry.cc)
+cc_library(scope_lite SRCS scope.cc)
 add_subdirectory(x86)
 add_subdirectory(cuda)
 add_subdirectory(operators)
 add_subdirectory(kernels)
 add_subdirectory(model_parser)
+
+# tests
+cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
@@ -19,7 +19,7 @@
 namespace paddle {
 namespace lite {

-void* TargetMalloc(TargetType target, size_t size) {
+static void* TargetMalloc(TargetType target, size_t size) {
   void* data{nullptr};
   switch (static_cast<int>(target)) {
     case static_cast<int>(TargetType::kX86):
@@ -40,7 +40,7 @@ void* TargetMalloc(TargetType target, size_t size) {
   return data;
 }

-void TargetFree(TargetType target, void* data) {
+static void TargetFree(TargetType target, void* data) {
   switch (static_cast<int>(target)) {
     case static_cast<int>(TargetType::kX86):
       TargetWrapper<TARGET(kX86)>::Free(data);
@@ -59,6 +59,7 @@ void TargetFree(TargetType target, void* data) {
 // Memory buffer manager.
 class Buffer {
  public:
+  Buffer() = default;
   Buffer(TargetType target, size_t size) : space_(size), target_(target) {}

   void* data() const { return data_; }
...
-cc_library(model_parser_lite SRCS model_parser.cc)
+cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite)
 cc_library(runtime_lite SRCS runtime.cc)
+cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
@@ -12,8 +12,137 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-//
-// Created by chunwei on 19-2-25.
-//
+#include "paddle/fluid/lite/model_parser/model_parser.h"
+#include <fstream>
+#include "paddle/fluid/lite/scope.h"
+#include "paddle/fluid/lite/tensor.h"
+#include "paddle/fluid/lite/variable.h"
namespace paddle {
namespace lite {
int SizeOfType(framework::proto::VarType::Type type) {
using Type = framework::proto::VarType::Type;
switch (static_cast<int>(type)) {
#define DO(desc, type) \
case Type::VarType_Type_##desc: \
return sizeof(type);
DO(BOOL, bool);
DO(FP16, float);
DO(FP32, float);
DO(INT8, int8_t);
DO(INT32, int);
DO(INT64, int64_t);
#undef DO
default:
LOG(FATAL) << "unknown data type";
}
}
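// On-disk tensor layout, as consumed by TensorFromStream() below:
//   uint32_t version                       -- only version 0 is supported
//   int32_t  size                          -- byte length of the TensorDesc proto
//   char[size]                             -- serialized TensorDesc (dims, data type)
//   char[product(dims) * SizeOfType(type)] -- the raw tensor content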
void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
using Type = framework::proto::VarType::Type;
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
CHECK_EQ(version, 0U) << "Only version 0 is supported";
// read tensor desc
framework::proto::VarType::TensorDesc desc;
{
// int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc";
}
// read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(dims);
void *buf;
size_t size = product(tensor->dims()) * SizeOfType(desc.data_type());
// allocate memory
switch (static_cast<int>(desc.data_type())) {
#define DO(desc, type) \
case Type::VarType_Type_##desc: \
buf = tensor->mutable_data<type>(); \
break;
DO(BOOL, bool);
DO(FP32, float);
DO(INT8, int8_t);
DO(INT16, int16_t);
DO(INT32, int32_t);
DO(INT64, int64_t);
#undef DO
default:
LOG(FATAL) << "unknown type";
}
is.read(static_cast<char *>(buf), size);
}
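// A serialized LoDTensor prepends its LoD metadata to the tensor stream above:
//   uint32_t version
//   uint64_t lod_level
//   per level: uint64_t size in bytes, then that many bytes of size_t offsets
// and is followed by the tensor stream parsed by TensorFromStream().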
void LoadLoDTensor(std::istream &is, Variable *var) {
auto *tensor = var->GetMutable<lite::Tensor>();
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
LOG(INFO) << "model version " << version;
// Load LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
TensorFromStream(is, tensor);
}
// TODO(Superjomn) support SelectedRows.
void ReadBinaryFile(const std::string &filename, std::string *contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
CHECK(fin.is_open()) << "Cannot open file " << filename;
fin.seekg(0, std::ios::end);
auto size = fin.tellg();
contents->clear();
contents->resize(size);
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
const std::string &path) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
std::unique_ptr<framework::proto::ProgramDesc> main_program(
new framework::proto::ProgramDesc);
main_program->ParseFromString(desc_str);
return main_program;
}
void LoadParams(const std::string &path) {}
void LoadModel(const std::string &model_dir, Scope *scope) {
const std::string prog_path = model_dir + "/__model__";
auto prog = LoadProgram(prog_path);
auto main_block = prog->blocks(0);
for (auto &var : main_block.vars()) {
std::string file_path = model_dir + "/" + var.name();
std::ifstream file(file_path, std::ios::binary);
LoadLoDTensor(file, scope->Var(var.name()));
}
}
#include "model_parser.h" } // namespace lite
} // namespace paddle
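For orientation, a minimal driver for the loader above; the checkpoint path is hypothetical, and a fluid save_inference_model directory keeps the program in __model__ plus one file per parameter:

#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/scope.h"

int main() {
  paddle::lite::Scope scope;
  // Hypothetical model directory produced by fluid's save_inference_model.
  paddle::lite::LoadModel("/path/to/fluid_checkpoint", &scope);
  return 0;
}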
@@ -15,15 +15,22 @@
 // This file contains model format related operations, such as load a model,
 // parse operator definitions and so on.

+#include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/framework.pb.h"

 namespace paddle {
 namespace lite {

+class Scope;  // defined in paddle/fluid/lite/scope.h

-void LoadProgram(const std::string& path);
+// Read a __model__ file.
+std::unique_ptr<framework::proto::ProgramDesc> LoadProgram(
+    const std::string& path);

+// Read a single file containing all the parameters.
 void LoadParams(const std::string& path);

-void LoadModel(const std::string& model_dir);
+// Read a model and files of parameters.
+void LoadModel(const std::string& model_dir, Scope* scope);

 }  // namespace lite
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
TEST(ModelParser, LoadProgram) {
auto program = LoadProgram(
"/home/chunwei/project2/models/fc/fluid_checkpoint/__model__");
}
} // namespace lite
} // namespace paddle
@@ -64,9 +64,10 @@ class OpLite : public Registry {
                     framework::Scope *scope) = 0;

   virtual std::string DebugString() const = 0;

-  virtual void StaticPickKernel(const std::vector<OpTarget> &valid_targets) = 0;
+  virtual void StaticPickKernel(
+      const std::vector<TargetType> &valid_targets) = 0;

-  void PickBestKernel(const std::vector<OpTarget> &valid_places,
+  void PickBestKernel(const std::vector<TargetType> &valid_places,
                       KernelStrategy kernel_strategy = KernelStrategy::kStatic);

   // Create all the kernels for the valid targets.
...
@@ -49,7 +49,8 @@ class FcOpLite : public OpLite {

   std::string DebugString() const override { return "fc"; }

-  void StaticPickKernel(const std::vector<OpTarget>& valid_targets) override {}
+  void StaticPickKernel(const std::vector<TargetType>& valid_targets) override {
+  }

  private:
   mutable FcParam param_;
...
@@ -13,3 +13,47 @@
 // limitations under the License.

 #include "paddle/fluid/lite/scope.h"
namespace paddle {
namespace lite {
Scope::~Scope() {
  // Scopes in kids_ are owned by this scope, so release them here.
  for (auto *kid : kids_) delete kid;
}
Scope &Scope::NewScope() const {
kids_.push_back(new Scope);
kids_.back()->parent_ = this;
return *kids_.back();
}
Variable *Scope::Var(const std::string &name) {
auto *var = FindVar(name);
if (var) return var;
// create a new variable.
vars_.emplace(name, std::unique_ptr<Variable>(new Variable));
return vars_[name].get();
}
Variable *Scope::FindVar(const std::string &name) const {
Variable *var{nullptr};
var = FindLocalVar(name);
const Scope *cur_scope = this;
while (!var && cur_scope->parent()) {
cur_scope = cur_scope->parent();
var = cur_scope->FindLocalVar(name);
}
return var;
}
Variable *Scope::FindLocalVar(const std::string &name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
}
return nullptr;
}
} // namespace lite
} // namespace paddle
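FindVar falls back to the parent chain, so a child scope sees its ancestors' variables. A minimal sketch, assuming the Scope above:

#include <cassert>
#include "paddle/fluid/lite/scope.h"

int main() {
  paddle::lite::Scope root;
  root.Var("w");                     // created in the root scope
  auto& child = root.NewScope();
  assert(child.FindVar("w"));        // visible via the parent chain
  assert(!child.FindLocalVar("w"));  // but not a local of the child
  return 0;
}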
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/lite/variable.h"

 namespace paddle {
 namespace lite {
@@ -30,7 +31,7 @@ class Scope final {

   Scope& NewScope() const;

-  Variable* Var(std::string* name = nullptr);
+  Variable* Var(const std::string& name);

   Variable* FindVar(const std::string& name) const;

@@ -42,6 +43,7 @@ class Scope final {
   // Scopes in `kids_` are owned by this class.
   mutable std::list<Scope*> kids_;
   const Scope* parent_{nullptr};
+  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
 };

 }  // namespace lite
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/scope.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
TEST(Scope, Var) {
Scope scope;
auto* x = scope.Var("x");
*x->GetMutable<int>() = 100;
ASSERT_EQ(x->Get<int>(), 100);
}
TEST(Scope, FindVar) {
Scope scope;
ASSERT_FALSE(scope.FindVar("x"));
scope.Var("x");
ASSERT_TRUE(scope.FindVar("x"));
}
} // namespace lite
} // namespace paddle
@@ -37,25 +37,28 @@ class EventTree {
   std::vector<event_t> children_;
 };

-using DDim = std::vector<int>;
+using DDim = std::vector<int64_t>;

-DDim SliceDims(const DDim& dims, int begin, int end) {
+static DDim SliceDims(const DDim& dims, int begin, int end) {
   // Half-open slice [begin, end) of the dimensions.
   return DDim(dims.begin() + begin, dims.begin() + end);
 }

-int product(const DDim& dims) {
+static int product(const DDim& dims) {
   return std::accumulate(dims.begin(), dims.end(), 1,
                          [](int a, int b) { return a * b; });
 }

-DDim flatten_to_2d(const DDim& dims, int col) {
+static DDim flatten_to_2d(const DDim& dims, int col) {
   return DDim({product(SliceDims(dims, 0, col)),
                product(SliceDims(dims, col, dims.size()))});
 }

+using LoD = std::vector<std::vector<size_t>>;

 // A light-weight tensor implementation.
 class Tensor {
  public:
   void SyncEventTree();
+  Tensor() = default;

   template <typename T>
   const T* data() const {
@@ -66,6 +69,9 @@ class Tensor {

   const DDim& dims() const { return dims_; }

+  const LoD& lod() { return lod_; }
+  LoD* mutable_lod() { return &lod_; }

   template <typename T>
   T* mutable_data() {
     // Allocate in bytes: element count times element size.
     buffer_.ResetLazy(target_, product(dims_) * sizeof(T));
@@ -78,6 +84,7 @@
   TargetType target_{TargetType::kHost};
   DDim dims_;
   Buffer buffer_;
+  LoD lod_;
 };

 }  // namespace lite
...
@@ -14,7 +14,7 @@
 #pragma once

-#include "paddle/fluid/lite/utils/any.h"
+#include "paddle/fluid/lite/utils/varient.h"
 #include "paddle/fluid/lite/utils/check.h"
 #include "paddle/fluid/lite/utils/factory.h"
 #include "paddle/fluid/lite/utils/macros.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <exception>
#include <memory>
#include <type_traits>
#include <typeinfo>
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and to keep the inference library
// small. This code references https://gist.github.com/shoooe/9202235
namespace paddle {
namespace lite {
class any;
template <class Type>
Type any_cast(any&);
template <class Type>
Type any_cast(const any&);
template <class Type>
Type* any_cast(any*);
template <class Type>
const Type* any_cast(const any*);
struct bad_any_cast : public std::bad_cast {};
class any {
public:
template <class Type>
friend Type any_cast(any&);
template <class Type>
friend Type any_cast(const any&);
template <class Type>
friend Type* any_cast(any*);
template <class Type>
friend const Type* any_cast(const any*);
any() : ptr(nullptr) {}
explicit any(any&& x) : ptr(std::move(x.ptr)) {}
explicit any(const any& x) {
if (x.ptr) ptr = x.ptr->clone();
}
template <class Type>
explicit any(const Type& x)
: ptr(new concrete<typename std::decay<const Type>::type>(x)) {}
any& operator=(any&& rhs) {
ptr = std::move(rhs.ptr);
return (*this);
}
any& operator=(const any& rhs) {
ptr = std::move(any(rhs).ptr);
return (*this);
}
template <class T>
any& operator=(T&& x) {
ptr.reset(new concrete<typename std::decay<T>::type>(
typename std::decay<T>::type(x)));
return (*this);
}
template <class T>
any& operator=(const T& x) {
ptr.reset(new concrete<typename std::decay<T>::type>(
typename std::decay<T>::type(x)));
return (*this);
}
void clear() { ptr.reset(nullptr); }
bool empty() const { return ptr == nullptr; }
const std::type_info& type() const {
return (!empty()) ? ptr->type() : typeid(void);
}
private:
struct placeholder {
virtual std::unique_ptr<placeholder> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual ~placeholder() {}
};
template <class T>
struct concrete : public placeholder {
explicit concrete(T&& x) : value(std::move(x)) {}
explicit concrete(const T& x) : value(x) {}
virtual std::unique_ptr<placeholder> clone() const override {
return std::unique_ptr<placeholder>(new concrete<T>(value));
}
virtual const std::type_info& type() const override { return typeid(T); }
T value;
};
std::unique_ptr<placeholder> ptr;
};
template <class Type>
Type any_cast(any& val) {
  // Guard against an empty any before touching the holder.
  if (val.empty() || val.ptr->type() != typeid(Type)) throw bad_any_cast();
  return static_cast<any::concrete<Type>*>(val.ptr.get())->value;
}
template <class Type>
Type any_cast(const any& val) {
return any_cast<Type>(any(val));
}
template <class Type>
Type* any_cast(any* ptr) {
  // Cast the type-erased holder back to the concrete wrapper; yields
  // nullptr when the stored type is not Type.
  auto* holder = dynamic_cast<any::concrete<Type>*>(ptr->ptr.get());
  return holder ? &holder->value : nullptr;
}
template <class Type>
const Type* any_cast(const any* ptr) {
  auto* holder = dynamic_cast<const any::concrete<Type>*>(ptr->ptr.get());
  return holder ? &holder->value : nullptr;
}
} // namespace lite
} // namespace paddle
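A short usage sketch of this any (values are illustrative; the include path follows all.h above):

#include <cassert>
#include <string>
#include "paddle/fluid/lite/utils/any.h"

int main() {
  paddle::lite::any a;
  assert(a.empty());
  a = 42;  // operator=(T&&) boxes an int
  assert(a.type() == typeid(int));
  assert(paddle::lite::any_cast<int>(a) == 42);
  a = std::string("hello");  // rebinding to another type is fine
  a.clear();
  assert(a.empty());
  return 0;
}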
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <exception>
#include <memory>
#include <type_traits>
#include <typeinfo>
// This is an equivalent implementation of boost::variant. We implement this
// to avoid including the whole boost library and to keep the inference
// library small. This code references https://gist.github.com/shoooe/9202235
namespace paddle {
namespace lite {
template <size_t arg1, size_t... others>
struct static_max;
template <size_t arg>
struct static_max<arg> {
static const size_t value = arg;
};
template <size_t arg1, size_t arg2, size_t... others>
struct static_max<arg1, arg2, others...> {
static const size_t value = arg1 >= arg2 ? static_max<arg1, others...>::value
: static_max<arg2, others...>::value;
};
template <typename... Ts>
struct variant_helper;
template <typename F, typename... Ts>
struct variant_helper<F, Ts...> {
inline static void destroy(size_t id, void* data) {
if (id == typeid(F).hash_code())
reinterpret_cast<F*>(data)->~F();
else
variant_helper<Ts...>::destroy(id, data);
}
inline static void move(size_t old_t, void* old_v, void* new_v) {
if (old_t == typeid(F).hash_code())
new (new_v) F(std::move(*reinterpret_cast<F*>(old_v)));
else
variant_helper<Ts...>::move(old_t, old_v, new_v);
}
inline static void copy(size_t old_t, const void* old_v, void* new_v) {
if (old_t == typeid(F).hash_code())
new (new_v) F(*reinterpret_cast<const F*>(old_v));
else
variant_helper<Ts...>::copy(old_t, old_v, new_v);
}
};
template <>
struct variant_helper<> {
inline static void destroy(size_t id, void* data) {}
inline static void move(size_t old_t, void* old_v, void* new_v) {}
inline static void copy(size_t old_t, const void* old_v, void* new_v) {}
};
template <typename... Ts>
struct variant {
private:
static const size_t data_size = static_max<sizeof(Ts)...>::value;
static const size_t data_align = static_max<alignof(Ts)...>::value;
using data_t = typename std::aligned_storage<data_size, data_align>::type;
using helper_t = variant_helper<Ts...>;
static inline size_t invalid_type() { return typeid(void).hash_code(); }
size_t type_id;
data_t data;
public:
variant() : type_id(invalid_type()) {}
variant(const variant<Ts...>& old) : type_id(old.type_id) {
helper_t::copy(old.type_id, &old.data, &data);
}
variant(variant<Ts...>&& old) : type_id(old.type_id) {
helper_t::move(old.type_id, &old.data, &data);
}
// Serves as both the move and the copy assignment operator.
variant<Ts...>& operator=(variant<Ts...> old) {
std::swap(type_id, old.type_id);
std::swap(data, old.data);
return *this;
}
template <typename T>
bool is() const {
  return (type_id == typeid(T).hash_code());
}
size_t type() const { return type_id; }
bool valid() const { return (type_id != invalid_type()); }
template <typename T, typename... Args>
void set(Args&&... args) {
// First we destroy the current contents
helper_t::destroy(type_id, &data);
new (&data) T(std::forward<Args>(args)...);
type_id = typeid(T).hash_code();
}
template <typename T>
T& get() {
// It is a dynamic_cast-like behaviour
if (type_id == typeid(T).hash_code())
return *reinterpret_cast<T*>(&data);
else
throw std::bad_cast();
}
~variant() { helper_t::destroy(type_id, &data); }
};
} // namespace lite
} // namespace paddle
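A matching usage sketch for this variant (types are illustrative):

#include <cassert>
#include <string>
#include "paddle/fluid/lite/utils/varient.h"

int main() {
  paddle::lite::variant<int, std::string> v;
  assert(!v.valid());  // nothing stored yet
  v.set<std::string>("hi");
  assert(v.is<std::string>());
  assert(v.get<std::string>() == "hi");
  v.set<int>(3);  // destroys the string, stores an int
  assert(v.get<int>() == 3);
  bool threw = false;
  try {
    v.get<std::string>();  // wrong type: throws std::bad_cast
  } catch (const std::bad_cast&) {
    threw = true;
  }
  assert(threw);
  return 0;
}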
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include "paddle/fluid/lite/tensor.h"
 #include "paddle/fluid/lite/utils/all.h"

 namespace paddle {
@@ -21,17 +22,23 @@ namespace lite {

 class Variable {
  public:
   template <typename T>
-  T& Get() {
-    return blob_;
-  }
+  const T& Get() {
+    return blob_.get<T>();
+  }

   template <typename T>
   T* GetMutable() {
-    return any_cast<T>(&blob_);
+    // Switch the blob to T lazily; keep the current value if it already is T.
+    if (!blob_.is<T>()) blob_.set<T>();
+    return &blob_.get<T>();
   }

+  template <typename T>
+  bool IsType() {
+    return blob_.type() == typeid(T).hash_code();
+  }

  private:
-  any blob_;
+  variant<int, float, std::string, Tensor> blob_;
 };

 }  // namespace lite
...
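And a small sketch of Variable on top of that variant (illustrative values):

#include <cassert>
#include "paddle/fluid/lite/variable.h"

int main() {
  paddle::lite::Variable var;
  *var.GetMutable<int>() = 100;  // creates the int slot, then writes it
  assert(var.IsType<int>());
  assert(var.Get<int>() == 100);
  return 0;
}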