Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
4f717825
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4f717825
编写于
6月 05, 2020
作者:
L
luxuhui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: fix input tensor dtype error
N/A Signed-off-by:
N
Luxuhui
<
luxuhui@xiaomi.com
>
上级
80835bbe
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
36 addition
and
19 deletion
+36
-19
mace/libmace/mace.cc
mace/libmace/mace.cc
+36
-19
未找到文件。
mace/libmace/mace.cc
浏览文件 @
4f717825
...
@@ -453,6 +453,8 @@ class MaceEngine::Impl {
...
@@ -453,6 +453,8 @@ class MaceEngine::Impl {
MaceStatus
TransposeOutput
(
const
Tensor
*
output_tensor
,
MaceStatus
TransposeOutput
(
const
Tensor
*
output_tensor
,
std
::
pair
<
const
std
::
string
,
MaceTensor
>
*
output
);
std
::
pair
<
const
std
::
string
,
MaceTensor
>
*
output
);
Tensor
*
CreateInputTensor
(
const
std
::
string
&
input_name
,
DataType
input_dt
);
private:
private:
std
::
unique_ptr
<
port
::
ReadOnlyMemoryRegion
>
model_data_
;
std
::
unique_ptr
<
port
::
ReadOnlyMemoryRegion
>
model_data_
;
std
::
unique_ptr
<
OpRegistry
>
op_registry_
;
std
::
unique_ptr
<
OpRegistry
>
op_registry_
;
...
@@ -554,6 +556,20 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
...
@@ -554,6 +556,20 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
MACE_CHECK_NOTNULL
(
device_
);
MACE_CHECK_NOTNULL
(
device_
);
}
}
Tensor
*
MaceEngine
::
Impl
::
CreateInputTensor
(
const
std
::
string
&
input_name
,
DataType
input_dt
)
{
Tensor
*
input_tensor
=
nullptr
;
if
(
input_dt
==
DT_FLOAT
&&
(
net_data_type_
==
DT_BFLOAT16
||
net_data_type_
==
DT_FLOAT16
))
{
input_tensor
=
ws_
->
CreateTensor
(
input_name
,
device_
->
allocator
(),
net_data_type_
);
}
else
{
input_tensor
=
ws_
->
CreateTensor
(
input_name
,
device_
->
allocator
(),
input_dt
);
}
return
input_tensor
;
}
MaceStatus
MaceEngine
::
Impl
::
Init
(
MaceStatus
MaceEngine
::
Impl
::
Init
(
const
NetDef
*
net_def
,
const
NetDef
*
net_def
,
const
std
::
vector
<
std
::
string
>
&
input_nodes
,
const
std
::
vector
<
std
::
string
>
&
input_nodes
,
...
@@ -584,8 +600,7 @@ MaceStatus MaceEngine::Impl::Init(
...
@@ -584,8 +600,7 @@ MaceStatus MaceEngine::Impl::Init(
<<
MakeString
(
MapKeys
(
input_info_map_
));
<<
MakeString
(
MapKeys
(
input_info_map_
));
}
}
DataType
input_dt
=
input_info_map_
[
input_name
].
data_type
();
DataType
input_dt
=
input_info_map_
[
input_name
].
data_type
();
Tensor
*
input_tensor
=
Tensor
*
input_tensor
=
CreateInputTensor
(
input_name
,
input_dt
);
ws_
->
CreateTensor
(
input_name
,
device_
->
allocator
(),
input_dt
);
// Resize to possible largest shape to avoid resize during running.
// Resize to possible largest shape to avoid resize during running.
std
::
vector
<
index_t
>
shape
(
input_info_map_
[
input_name
].
dims_size
());
std
::
vector
<
index_t
>
shape
(
input_info_map_
[
input_name
].
dims_size
());
for
(
int
i
=
0
;
i
<
input_info_map_
[
input_name
].
dims_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
input_info_map_
[
input_name
].
dims_size
();
++
i
)
{
...
@@ -765,8 +780,12 @@ MaceStatus MaceEngine::Impl::TransposeInput(
...
@@ -765,8 +780,12 @@ MaceStatus MaceEngine::Impl::TransposeInput(
input
.
second
.
shape
(),
input
.
second
.
shape
(),
dst_dims
,
dst_dims
,
input_data
);
input_data
);
#ifdef MACE_ENABLE_BFLOAT16
}
else
{
}
else
if
(
net_data_type_
==
DT_BFLOAT16
)
{
LOG
(
FATAL
)
<<
"Invalid net data type: "
<<
net_data_type_
;
}
#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro
}
else
if
(
input_dt
==
DataType
::
DT_FLOAT16
||
input_dt
==
DataType
::
DT_BFLOAT16
)
{
auto
*
input_data
=
input_tensor
->
mutable_data
<
BFloat16
>
();
auto
*
input_data
=
input_tensor
->
mutable_data
<
BFloat16
>
();
return
ops
::
Transpose
(
thread_pool_
.
get
(),
return
ops
::
Transpose
(
thread_pool_
.
get
(),
input
.
second
.
data
<
float
>
().
get
(),
input
.
second
.
data
<
float
>
().
get
(),
...
@@ -774,9 +793,6 @@ MaceStatus MaceEngine::Impl::TransposeInput(
...
@@ -774,9 +793,6 @@ MaceStatus MaceEngine::Impl::TransposeInput(
dst_dims
,
dst_dims
,
input_data
);
input_data
);
#endif // MACE_ENABLE_BFLOAT16
#endif // MACE_ENABLE_BFLOAT16
}
else
{
LOG
(
FATAL
)
<<
"Invalid net data type: "
<<
net_data_type_
;
}
}
else
if
(
input_dt
==
DataType
::
DT_INT32
)
{
}
else
if
(
input_dt
==
DataType
::
DT_INT32
)
{
auto
input_data
=
input_tensor
->
mutable_data
<
int
>
();
auto
input_data
=
input_tensor
->
mutable_data
<
int
>
();
return
ops
::
Transpose
(
thread_pool_
.
get
(),
return
ops
::
Transpose
(
thread_pool_
.
get
(),
...
@@ -800,17 +816,18 @@ MaceStatus MaceEngine::Impl::TransposeInput(
...
@@ -800,17 +816,18 @@ MaceStatus MaceEngine::Impl::TransposeInput(
auto
input_data
=
input_tensor
->
mutable_data
<
float
>
();
auto
input_data
=
input_tensor
->
mutable_data
<
float
>
();
memcpy
(
input_data
,
input
.
second
.
data
().
get
(),
memcpy
(
input_data
,
input
.
second
.
data
().
get
(),
input_tensor
->
size
()
*
sizeof
(
float
));
input_tensor
->
size
()
*
sizeof
(
float
));
#ifdef MACE_ENABLE_BFLOAT16
}
else
{
}
else
if
(
net_data_type_
==
DataType
::
DT_BFLOAT16
)
{
LOG
(
FATAL
)
<<
"Invalid net data type: "
<<
net_data_type_
;
}
#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro
}
else
if
(
input_dt
==
DataType
::
DT_FLOAT16
||
input_dt
==
DataType
::
DT_BFLOAT16
)
{
auto
input_data
=
input_tensor
->
mutable_data
<
BFloat16
>
();
auto
input_data
=
input_tensor
->
mutable_data
<
BFloat16
>
();
const
float
*
data
=
input
.
second
.
data
().
get
();
const
float
*
data
=
input
.
second
.
data
().
get
();
for
(
index_t
i
=
0
;
i
<
input_tensor
->
size
();
++
i
)
{
for
(
index_t
i
=
0
;
i
<
input_tensor
->
size
();
++
i
)
{
input_data
[
i
]
=
data
[
i
];
input_data
[
i
]
=
data
[
i
];
}
}
#endif // MACE_ENABLE_BFLOAT16
#endif // MACE_ENABLE_BFLOAT16
}
else
{
LOG
(
FATAL
)
<<
"Invalid net data type: "
<<
net_data_type_
;
}
}
else
if
(
input_dt
==
DataType
::
DT_INT32
)
{
}
else
if
(
input_dt
==
DataType
::
DT_INT32
)
{
auto
input_data
=
input_tensor
->
mutable_data
<
int
>
();
auto
input_data
=
input_tensor
->
mutable_data
<
int
>
();
memcpy
(
input_data
,
input
.
second
.
data
().
get
(),
memcpy
(
input_data
,
input
.
second
.
data
().
get
(),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录