提交 a2a46b56 编写于 作者: M Megvii Engine Team

fix(lite): fix rknn error in lite

GitOrigin-RevId: b66aa1bf73af8c2993c66f52cc45b991a102d0fa
上级 849f0ece
......@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key(
* other config not inclue in config and networkIO, ParseInfoFunc can fill it
* with the information in json, now support:
* "device_id" : int, default 0
* "number_threads" : size_t, default 1
* "number_threads" : uint32_t, default 1
* "is_inplace_model" : bool, default false
* "use_tensorrt" : bool, default false
*/
......
......@@ -149,28 +149,42 @@ private:
*/
class LITE_API LiteAny {
public:
enum Type {
STRING = 0,
INT32 = 1,
UINT32 = 2,
UINT8 = 3,
INT8 = 4,
INT64 = 5,
UINT64 = 6,
BOOL = 7,
VOID_PTR = 8,
FLOAT = 9,
NONE_SUPPORT = 10,
};
LiteAny() = default;
template <class T>
LiteAny(T value) : m_holder(new AnyHolder<T>(value)) {
m_is_string = std::is_same<std::string, T>();
m_type = get_type<T>();
}
LiteAny(const LiteAny& any) {
m_holder = any.m_holder->clone();
m_is_string = any.is_string();
m_type = any.m_type;
}
LiteAny& operator=(const LiteAny& any) {
m_holder = any.m_holder->clone();
m_is_string = any.is_string();
m_type = any.m_type;
return *this;
}
bool is_string() const { return m_is_string; }
template <class T>
Type get_type() const;
class HolderBase {
public:
virtual ~HolderBase() = default;
virtual std::shared_ptr<HolderBase> clone() = 0;
virtual size_t type_length() const = 0;
};
template <class T>
......@@ -180,7 +194,6 @@ public:
virtual std::shared_ptr<HolderBase> clone() override {
return std::make_shared<AnyHolder>(m_value);
}
virtual size_t type_length() const override { return sizeof(T); }
public:
T m_value;
......@@ -188,14 +201,21 @@ public:
//! if type is miss matching, it will throw
void type_missmatch(size_t expect, size_t get) const;
//! only check the storage type and the visit type length, so it's not safe
template <class T>
T unsafe_cast() const {
if (sizeof(T) != m_holder->type_length()) {
type_missmatch(m_holder->type_length(), sizeof(T));
T safe_cast() const {
if (get_type<T>() != m_type) {
type_missmatch(m_type, get_type<T>());
}
return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value;
}
template <class T>
bool try_cast() const {
if (get_type<T>() == m_type) {
return true;
} else {
return false;
}
}
//! only check the storage type and the visit type length, so it's not safe
void* cast_void_ptr() const {
return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value;
......@@ -203,7 +223,7 @@ public:
private:
std::shared_ptr<HolderBase> m_holder;
bool m_is_string = false;
Type m_type = NONE_SUPPORT;
};
/*********************** special tensor function ***************/
......
......@@ -127,7 +127,8 @@ int LITE_register_parse_info_func(
separate_config_map["device_id"] = device_id;
}
if (nr_threads != 1) {
separate_config_map["nr_threads"] = nr_threads;
separate_config_map["nr_threads"] =
static_cast<uint32_t>(nr_threads);
}
if (is_cpu_inplace_mode != false) {
separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
......
......@@ -352,19 +352,19 @@ void NetworkImplDft::load_model(
//! config some flag get from json config file
if (separate_config_map.find("device_id") != separate_config_map.end()) {
set_device_id(separate_config_map["device_id"].unsafe_cast<int>());
set_device_id(separate_config_map["device_id"].safe_cast<int>());
}
if (separate_config_map.find("number_threads") != separate_config_map.end() &&
separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) {
separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
set_cpu_threads_number(
separate_config_map["number_threads"].unsafe_cast<size_t>());
separate_config_map["number_threads"].safe_cast<uint32_t>());
}
if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) {
separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
set_cpu_inplace_mode();
}
if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
separate_config_map["use_tensorrt"].unsafe_cast<bool>()) {
separate_config_map["use_tensorrt"].safe_cast<bool>()) {
use_tensorrt();
}
......
......@@ -84,7 +84,7 @@ bool default_parse_info(
}
if (device_json.contains("number_threads")) {
separate_config_map["number_threads"] =
static_cast<size_t>(device_json["number_threads"]);
static_cast<uint32_t>(device_json["number_threads"]);
}
if (device_json.contains("enable_inplace_model")) {
separate_config_map["enable_inplace_model"] =
......
......@@ -277,10 +277,28 @@ void Tensor::update_from_implement() {
void LiteAny::type_missmatch(size_t expect, size_t get) const {
LITE_THROW(ssprintf(
"The type store in LiteAny is not match the visit type, type of "
"storage length is %zu, type of visit length is %zu.",
"storage enum is %zu, type of visit enum is %zu.",
expect, get));
}
namespace lite {
#define GET_TYPE(ctype, ENUM) \
template <> \
LiteAny::Type LiteAny::get_type<ctype>() const { \
return ENUM; \
}
GET_TYPE(std::string, STRING)
GET_TYPE(int32_t, INT32)
GET_TYPE(uint32_t, UINT32)
GET_TYPE(int8_t, INT8)
GET_TYPE(uint8_t, UINT8)
GET_TYPE(int64_t, INT64)
GET_TYPE(uint64_t, UINT64)
GET_TYPE(float, FLOAT)
GET_TYPE(bool, BOOL)
GET_TYPE(void*, VOID_PTR)
} // namespace lite
std::shared_ptr<Tensor> TensorUtils::concat(
const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
int dst_device_id) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册