Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
6d12272f
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6d12272f
编写于
8月 07, 2019
作者:
李
李寅
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'quantize' into 'master'
Fix post quantize See merge request !1169
上级
ba98bf87
d71985fe
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
209 addition
and
129 deletion
+209
-129
mace/python/tools/converter_tool/hexagon_converter.py
mace/python/tools/converter_tool/hexagon_converter.py
+18
-12
mace/tools/mace_run.cc
mace/tools/mace_run.cc
+183
-114
tools/image/tensor_to_image.py
tools/image/tensor_to_image.py
+8
-3
未找到文件。
mace/python/tools/converter_tool/hexagon_converter.py
浏览文件 @
6d12272f
...
@@ -144,18 +144,23 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -144,18 +144,23 @@ class HexagonConverter(base_converter.ConverterInterface):
return
self
.
_model
return
self
.
_model
def
add_port_for_tensors
(
self
,
tensors
):
for
i
in
range
(
len
(
tensors
)):
if
':'
not
in
tensors
[
i
]:
node_name
=
tensors
[
i
]
tensors
[
i
]
+=
':0'
if
node_name
in
self
.
_quantize_activation_info
:
self
.
_quantize_activation_info
[
tensors
[
i
]]
=
\
self
.
_quantize_activation_info
[
node_name
]
def
convert_ops
(
self
):
def
convert_ops
(
self
):
print
(
"Convert mace graph to hexagon."
)
print
(
"Convert mace graph to hexagon."
)
for
op
in
self
.
_model
.
op
:
for
op
in
self
.
_model
.
op
:
if
not
self
.
_hexagon_ops
.
has_op
(
op
.
type
):
if
not
self
.
_hexagon_ops
.
has_op
(
op
.
type
):
raise
Exception
(
'Unsupported op: '
,
op
)
raise
Exception
(
'Unsupported op: '
,
op
)
for
i
in
range
(
len
(
op
.
input
)):
if
':'
not
in
op
.
input
[
i
]:
self
.
add_port_for_tensors
(
op
.
input
)
node_name
=
op
.
input
[
i
]
self
.
add_port_for_tensors
(
op
.
output
)
op
.
input
[
i
]
+=
':0'
if
node_name
in
self
.
_quantize_activation_info
:
self
.
_quantize_activation_info
[
op
.
input
[
i
]]
=
\
self
.
_quantize_activation_info
[
node_name
]
if
op
.
type
==
MaceOp
.
Conv2D
.
name
\
if
op
.
type
==
MaceOp
.
Conv2D
.
name
\
or
op
.
type
==
MaceOp
.
DepthwiseConv2d
.
name
:
or
op
.
type
==
MaceOp
.
DepthwiseConv2d
.
name
:
...
@@ -483,13 +488,15 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -483,13 +488,15 @@ class HexagonConverter(base_converter.ConverterInterface):
for
tensor
in
self
.
_model
.
tensors
:
for
tensor
in
self
.
_model
.
tensors
:
tensor
.
node_id
=
node_id_counter
tensor
.
node_id
=
node_id_counter
node_id_counter
+=
1
node_id_counter
+=
1
tensor_op
,
port
=
get_op_and_port_from_tensor
(
tensor
.
name
)
node_id_map
[
tensor
.
name
]
=
tensor
.
node_id
node_id_map
[
tensor_op
]
=
tensor
.
node_id
print
(
"Hexagon op:"
)
print
(
"Hexagon op:"
)
index
=
0
index
=
0
for
op
in
self
.
_model
.
op
:
for
op
in
self
.
_model
.
op
:
op
.
node_id
=
node_id_counter
op
.
node_id
=
node_id_counter
node_id_counter
+=
1
for
output
in
op
.
output
:
node_id_map
[
output
]
=
op
.
node_id
if
op
.
type
not
in
[
HexagonOp
.
QuantizeINPUT_f_to_8
,
if
op
.
type
not
in
[
HexagonOp
.
QuantizeINPUT_f_to_8
,
HexagonOp
.
DequantizeOUTPUT_8tof
.
name
]:
HexagonOp
.
DequantizeOUTPUT_8tof
.
name
]:
index_str
=
str
(
index
)
index_str
=
str
(
index
)
...
@@ -498,11 +505,10 @@ class HexagonConverter(base_converter.ConverterInterface):
...
@@ -498,11 +505,10 @@ class HexagonConverter(base_converter.ConverterInterface):
index_str
=
''
index_str
=
''
print
(
'Op: %s (%s, node_id:%d, index:%s)'
%
print
(
'Op: %s (%s, node_id:%d, index:%s)'
%
(
op
.
name
,
op
.
type
,
op
.
node_id
,
index_str
))
(
op
.
name
,
op
.
type
,
op
.
node_id
,
index_str
))
node_id_counter
+=
1
node_id_map
[
op
.
name
]
=
op
.
node_id
for
ipt
in
op
.
input
:
for
ipt
in
op
.
input
:
op_name
,
port
=
get_op_and_port_from_tensor
(
ipt
)
op_name
,
port
=
get_op_and_port_from_tensor
(
ipt
)
node_id
=
node_id_map
[
op_name
]
tensor_name
=
ipt
if
port
==
0
else
op_name
+
':0'
node_id
=
node_id_map
[
tensor_name
]
node_input
=
op
.
node_input
.
add
()
node_input
=
op
.
node_input
.
add
()
node_input
.
node_id
=
node_id
node_input
.
node_id
=
node_id
node_input
.
output_port
=
int
(
port
)
node_input
.
output_port
=
int
(
port
)
mace/tools/mace_run.cc
浏览文件 @
6d12272f
...
@@ -24,6 +24,8 @@
...
@@ -24,6 +24,8 @@
* --model_data_file=model_data.data \
* --model_data_file=model_data.data \
* --device=GPU
* --device=GPU
*/
*/
#include <sys/types.h>
#include <dirent.h>
#include <stdint.h>
#include <stdint.h>
#include <cstdio>
#include <cstdio>
#include <cstdlib>
#include <cstdlib>
...
@@ -276,6 +278,7 @@ bool RunModel(const std::string &model_name,
...
@@ -276,6 +278,7 @@ bool RunModel(const std::string &model_name,
std
::
map
<
std
::
string
,
mace
::
MaceTensor
>
inputs
;
std
::
map
<
std
::
string
,
mace
::
MaceTensor
>
inputs
;
std
::
map
<
std
::
string
,
mace
::
MaceTensor
>
outputs
;
std
::
map
<
std
::
string
,
mace
::
MaceTensor
>
outputs
;
std
::
map
<
std
::
string
,
int64_t
>
inputs_size
;
for
(
size_t
i
=
0
;
i
<
input_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
input_count
;
++
i
)
{
// Allocate input and output
// Allocate input and output
// only support float and int32, use char for generalization
// only support float and int32, use char for generalization
...
@@ -283,6 +286,7 @@ bool RunModel(const std::string &model_name,
...
@@ -283,6 +286,7 @@ bool RunModel(const std::string &model_name,
int64_t
input_size
=
int64_t
input_size
=
std
::
accumulate
(
input_shapes
[
i
].
begin
(),
input_shapes
[
i
].
end
(),
4
,
std
::
accumulate
(
input_shapes
[
i
].
begin
(),
input_shapes
[
i
].
end
(),
4
,
std
::
multiplies
<
int64_t
>
());
std
::
multiplies
<
int64_t
>
());
inputs_size
[
input_names
[
i
]]
=
input_size
;
auto
buffer_in
=
std
::
shared_ptr
<
char
>
(
new
char
[
input_size
],
auto
buffer_in
=
std
::
shared_ptr
<
char
>
(
new
char
[
input_size
],
std
::
default_delete
<
char
[]
>
());
std
::
default_delete
<
char
[]
>
());
// load input
// load input
...
@@ -310,90 +314,139 @@ bool RunModel(const std::string &model_name,
...
@@ -310,90 +314,139 @@ bool RunModel(const std::string &model_name,
output_data_formats
[
i
]);
output_data_formats
[
i
]);
}
}
LOG
(
INFO
)
<<
"Warm up run"
;
if
(
!
FLAGS_input_dir
.
empty
())
{
double
warmup_millis
;
DIR
*
dir_parent
;
while
(
true
)
{
struct
dirent
*
entry
;
int64_t
t3
=
NowMicros
();
dir_parent
=
opendir
(
FLAGS_input_dir
.
c_str
());
MaceStatus
warmup_status
=
engine
->
Run
(
inputs
,
&
outputs
);
if
(
dir_parent
)
{
if
(
warmup_status
!=
MaceStatus
::
MACE_SUCCESS
)
{
while
((
entry
=
readdir
(
dir_parent
)))
{
LOG
(
ERROR
)
<<
"Warmup runtime error, retry ... errcode: "
std
::
string
file_name
=
std
::
string
(
entry
->
d_name
);
<<
warmup_status
.
information
();
std
::
string
prefix
=
FormatName
(
input_names
[
0
]);
do
{
if
(
file_name
.
find
(
prefix
)
==
0
)
{
std
::
string
suffix
=
file_name
.
substr
(
prefix
.
size
());
for
(
size_t
i
=
0
;
i
<
input_count
;
++
i
)
{
file_name
=
FLAGS_input_dir
+
"/"
+
FormatName
(
input_names
[
i
])
+
suffix
;
std
::
ifstream
in_file
(
file_name
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
cout
<<
"Read "
<<
file_name
<<
std
::
endl
;
if
(
in_file
.
is_open
())
{
in_file
.
read
(
reinterpret_cast
<
char
*>
(
inputs
[
input_names
[
i
]].
data
().
get
()),
inputs_size
[
input_names
[
i
]]
*
sizeof
(
float
));
in_file
.
close
();
}
else
{
std
::
cerr
<<
"Open input file failed"
<<
std
::
endl
;
return
-
1
;
}
}
engine
->
Run
(
inputs
,
&
outputs
);
if
(
!
FLAGS_output_dir
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
output_count
;
++
i
)
{
std
::
string
output_name
=
FLAGS_output_dir
+
"/"
+
FormatName
(
output_names
[
i
])
+
suffix
;
std
::
ofstream
out_file
(
output_name
,
std
::
ios
::
binary
);
if
(
out_file
.
is_open
())
{
int64_t
output_size
=
std
::
accumulate
(
output_shapes
[
i
].
begin
(),
output_shapes
[
i
].
end
(),
1
,
std
::
multiplies
<
int64_t
>
());
out_file
.
write
(
reinterpret_cast
<
char
*>
(
outputs
[
output_names
[
i
]].
data
().
get
()),
output_size
*
sizeof
(
float
));
out_file
.
flush
();
out_file
.
close
();
}
else
{
std
::
cerr
<<
"Open output file failed"
<<
std
::
endl
;
return
-
1
;
}
}
}
}
}
closedir
(
dir_parent
);
}
else
{
std
::
cerr
<<
"Directory "
<<
FLAGS_input_dir
<<
" does not exist."
<<
std
::
endl
;
}
}
else
{
LOG
(
INFO
)
<<
"Warm up run"
;
double
warmup_millis
;
while
(
true
)
{
int64_t
t3
=
NowMicros
();
MaceStatus
warmup_status
=
engine
->
Run
(
inputs
,
&
outputs
);
if
(
warmup_status
!=
MaceStatus
::
MACE_SUCCESS
)
{
LOG
(
ERROR
)
<<
"Warmup runtime error, retry ... errcode: "
<<
warmup_status
.
information
();
do
{
#ifdef MODEL_GRAPH_FORMAT_CODE
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
create_engine_status
=
CreateMaceEngineFromCode
(
model_name
,
CreateMaceEngineFromCode
(
model_name
,
reinterpret_cast
<
const
unsigned
char
*>
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_weights_data
->
data
()),
model_weights_data
->
data
()),
model_weights_data
->
length
(),
model_weights_data
->
length
(),
input_names
,
input_names
,
output_names
,
output_names
,
config
,
config
,
&
engine
);
&
engine
);
#else
#else
create_engine_status
=
create_engine_status
=
CreateMaceEngineFromProto
(
reinterpret_cast
<
const
unsigned
char
*>
(
CreateMaceEngineFromProto
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_graph_data
->
data
()),
model_graph_data
->
data
()),
model_graph_data
->
length
(),
model_graph_data
->
length
(),
reinterpret_cast
<
const
unsigned
char
*>
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_weights_data
->
data
()),
model_weights_data
->
data
()),
model_weights_data
->
length
(),
model_weights_data
->
length
(),
input_names
,
input_names
,
output_names
,
output_names
,
config
,
config
,
&
engine
);
&
engine
);
#endif
#endif
}
while
(
create_engine_status
!=
MaceStatus
::
MACE_SUCCESS
);
}
while
(
create_engine_status
!=
MaceStatus
::
MACE_SUCCESS
);
}
else
{
}
else
{
int64_t
t4
=
NowMicros
();
int64_t
t4
=
NowMicros
();
warmup_millis
=
(
t4
-
t3
)
/
1000.0
;
warmup_millis
=
(
t4
-
t3
)
/
1000.0
;
LOG
(
INFO
)
<<
"1st warm up run latency: "
<<
warmup_millis
<<
" ms"
;
LOG
(
INFO
)
<<
"1st warm up run latency: "
<<
warmup_millis
<<
" ms"
;
break
;
break
;
}
}
}
}
double
model_run_millis
=
-
1
;
double
model_run_millis
=
-
1
;
benchmark
::
OpStat
op_stat
;
benchmark
::
OpStat
op_stat
;
if
(
FLAGS_round
>
0
)
{
if
(
FLAGS_round
>
0
)
{
LOG
(
INFO
)
<<
"Run model"
;
LOG
(
INFO
)
<<
"Run model"
;
int64_t
total_run_duration
=
0
;
int64_t
total_run_duration
=
0
;
for
(
int
i
=
0
;
i
<
FLAGS_round
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_round
;
++
i
)
{
std
::
unique_ptr
<
port
::
Logger
>
info_log
;
std
::
unique_ptr
<
port
::
Logger
>
info_log
;
std
::
unique_ptr
<
port
::
MallocLogger
>
malloc_logger
;
std
::
unique_ptr
<
port
::
MallocLogger
>
malloc_logger
;
if
(
FLAGS_malloc_check_cycle
>=
1
&&
i
%
FLAGS_malloc_check_cycle
==
0
)
{
if
(
FLAGS_malloc_check_cycle
>=
1
info_log
=
LOG_PTR
(
INFO
);
&&
i
%
FLAGS_malloc_check_cycle
==
0
)
{
malloc_logger
=
port
::
Env
::
Default
()
->
NewMallocLogger
(
info_log
=
LOG_PTR
(
INFO
);
info_log
.
get
(),
MakeString
(
i
));
malloc_logger
=
port
::
Env
::
Default
()
->
NewMallocLogger
(
}
info_log
.
get
(),
MakeString
(
i
));
MaceStatus
run_status
;
}
RunMetadata
metadata
;
MaceStatus
run_status
;
RunMetadata
*
metadata_ptr
=
nullptr
;
RunMetadata
metadata
;
if
(
FLAGS_benchmark
)
{
RunMetadata
*
metadata_ptr
=
nullptr
;
metadata_ptr
=
&
metadata
;
if
(
FLAGS_benchmark
)
{
}
metadata_ptr
=
&
metadata
;
}
while
(
true
)
{
while
(
true
)
{
int64_t
t0
=
NowMicros
();
int64_t
t0
=
NowMicros
();
run_status
=
engine
->
Run
(
inputs
,
&
outputs
,
metadata_ptr
);
run_status
=
engine
->
Run
(
inputs
,
&
outputs
,
metadata_ptr
);
if
(
run_status
!=
MaceStatus
::
MACE_SUCCESS
)
{
if
(
run_status
!=
MaceStatus
::
MACE_SUCCESS
)
{
LOG
(
ERROR
)
<<
"Mace run model runtime error, retry ... errcode: "
LOG
(
ERROR
)
<<
"Mace run model runtime error, retry ... errcode: "
<<
run_status
.
information
();
<<
run_status
.
information
();
do
{
do
{
#ifdef MODEL_GRAPH_FORMAT_CODE
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
create_engine_status
=
CreateMaceEngineFromCode
(
model_name
,
CreateMaceEngineFromCode
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_name
,
model_weights_data
->
data
()),
model_weights_data
->
length
(),
input_names
,
output_names
,
config
,
&
engine
);
#else
create_engine_status
=
CreateMaceEngineFromProto
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_graph_data
->
data
()),
model_graph_data
->
length
(),
reinterpret_cast
<
const
unsigned
char
*>
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_weights_data
->
data
()),
model_weights_data
->
data
()),
model_weights_data
->
length
(),
model_weights_data
->
length
(),
...
@@ -401,46 +454,60 @@ bool RunModel(const std::string &model_name,
...
@@ -401,46 +454,60 @@ bool RunModel(const std::string &model_name,
output_names
,
output_names
,
config
,
config
,
&
engine
);
&
engine
);
#else
create_engine_status
=
CreateMaceEngineFromProto
(
reinterpret_cast
<
const
unsigned
char
*>
(
model_graph_data
->
data
()),
model_graph_data
->
length
(),
reinterpret_cast
<
const
unsigned
char
*>
(
model_weights_data
->
data
()),
model_weights_data
->
length
(),
input_names
,
output_names
,
config
,
&
engine
);
#endif
#endif
}
while
(
create_engine_status
!=
MaceStatus
::
MACE_SUCCESS
);
}
while
(
create_engine_status
!=
MaceStatus
::
MACE_SUCCESS
);
}
else
{
}
else
{
int64_t
t1
=
NowMicros
();
int64_t
t1
=
NowMicros
();
total_run_duration
+=
(
t1
-
t0
);
total_run_duration
+=
(
t1
-
t0
);
if
(
FLAGS_benchmark
)
{
if
(
FLAGS_benchmark
)
{
op_stat
.
StatMetadata
(
metadata
);
op_stat
.
StatMetadata
(
metadata
);
}
break
;
}
}
break
;
}
}
}
}
model_run_millis
=
total_run_duration
/
1000.0
/
FLAGS_round
;
LOG
(
INFO
)
<<
"Average latency: "
<<
model_run_millis
<<
" ms"
;
}
}
model_run_millis
=
total_run_duration
/
1000.0
/
FLAGS_round
;
LOG
(
INFO
)
<<
"Average latency: "
<<
model_run_millis
<<
" ms"
;
}
for
(
size_t
i
=
0
;
i
<
output_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_count
;
++
i
)
{
std
::
string
output_name
=
std
::
string
output_name
=
FLAGS_output_file
+
"_"
+
FormatName
(
output_names
[
i
]);
FLAGS_output_file
+
"_"
+
FormatName
(
output_names
[
i
]);
std
::
ofstream
out_file
(
output_name
,
std
::
ios
::
binary
);
std
::
ofstream
out_file
(
output_name
,
std
::
ios
::
binary
);
// only support float and int32
// only support float and int32
int64_t
output_size
=
int64_t
output_size
=
std
::
accumulate
(
output_shapes
[
i
].
begin
(),
output_shapes
[
i
].
end
(),
4
,
std
::
accumulate
(
output_shapes
[
i
].
begin
(),
output_shapes
[
i
].
end
(),
4
,
std
::
multiplies
<
int64_t
>
());
std
::
multiplies
<
int64_t
>
());
out_file
.
write
(
out_file
.
write
(
outputs
[
output_names
[
i
]].
data
<
char
>
().
get
(),
output_size
);
outputs
[
output_names
[
i
]].
data
<
char
>
().
get
(),
output_size
);
out_file
.
flush
();
out_file
.
flush
();
out_file
.
close
();
out_file
.
close
();
LOG
(
INFO
)
<<
"Write output file "
<<
output_name
<<
" with size "
LOG
(
INFO
)
<<
"Write output file "
<<
output_name
<<
" with size "
<<
output_size
<<
" done."
;
<<
output_size
<<
" done."
;
}
}
// Metrics reporting tools depends on the format, keep in consistent
// Metrics reporting tools depends on the format, keep in consistent
printf
(
"========================================================
\n
"
);
printf
(
"========================================================
\n
"
);
printf
(
" capability(CPU) init warmup run_avg
\n
"
);
printf
(
" capability(CPU) init warmup run_avg
\n
"
);
printf
(
"========================================================
\n
"
);
printf
(
"========================================================
\n
"
);
printf
(
"time %15.3f %11.3f %11.3f %11.3f
\n
"
,
printf
(
"time %15.3f %11.3f %11.3f %11.3f
\n
"
,
cpu_capability
,
init_millis
,
warmup_millis
,
model_run_millis
);
cpu_capability
,
init_millis
,
warmup_millis
,
model_run_millis
);
if
(
FLAGS_benchmark
)
{
if
(
FLAGS_benchmark
)
{
op_stat
.
PrintStat
();
op_stat
.
PrintStat
();
}
}
}
return
true
;
return
true
;
...
@@ -514,10 +581,12 @@ int Main(int argc, char **argv) {
...
@@ -514,10 +581,12 @@ int Main(int argc, char **argv) {
output_data_formats
[
i
]
=
ParseDataFormat
(
raw_output_data_formats
[
i
]);
output_data_formats
[
i
]
=
ParseDataFormat
(
raw_output_data_formats
[
i
]);
}
}
float
cpu_float32_performance
=
0.0
f
;
// get cpu capability
if
(
FLAGS_input_dir
.
empty
())
{
Capability
cpu_capability
=
GetCapability
(
DeviceType
::
CPU
);
// get cpu capability
float
cpu_float32_performance
=
cpu_capability
.
float32_performance
.
exec_time
;
Capability
cpu_capability
=
GetCapability
(
DeviceType
::
CPU
);
cpu_float32_performance
=
cpu_capability
.
float32_performance
.
exec_time
;
}
bool
ret
=
false
;
bool
ret
=
false
;
for
(
int
i
=
0
;
i
<
FLAGS_restart_round
;
++
i
)
{
for
(
int
i
=
0
;
i
<
FLAGS_restart_round
;
++
i
)
{
...
...
tools/image/tensor_to_image.py
浏览文件 @
6d12272f
...
@@ -28,16 +28,21 @@ def parse_args():
...
@@ -28,16 +28,21 @@ def parse_args():
"--image_shape"
,
"--image_shape"
,
type
=
str
,
type
=
str
,
help
=
"target image shape, e.g, 224,224,3"
)
help
=
"target image shape, e.g, 224,224,3"
)
parser
.
add_argument
(
"--add_softmax"
,
action
=
"store_true"
,
help
=
"add softmax before convert to image"
)
return
parser
.
parse_known_args
()
return
parser
.
parse_known_args
()
def
tensors_to_images
(
input_files
,
image_shape
):
def
tensors_to_images
(
input_files
,
image_shape
,
add_softmax
):
with
tf
.
Graph
().
as_default
():
with
tf
.
Graph
().
as_default
():
input
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
image_shape
,
name
=
'input'
)
input
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
image_shape
,
name
=
'input'
)
output
=
tf
.
placeholder
(
tf
.
string
,
name
=
'output_file'
)
output
=
tf
.
placeholder
(
tf
.
string
,
name
=
'output_file'
)
if
add_softmax
:
input
=
tf
.
nn
.
softmax
(
input
)
# use the second channel if it is gray image
# use the second channel if it is gray image
if
image_shape
[
2
]
==
2
:
if
image_shape
[
2
]
==
2
:
input
=
tf
.
nn
.
softmax
(
input
)
_
,
input
=
tf
.
split
(
input
,
2
,
axis
=
2
)
_
,
input
=
tf
.
split
(
input
,
2
,
axis
=
2
)
tensor_data
=
tf
.
image
.
convert_image_dtype
(
input
,
tensor_data
=
tf
.
image
.
convert_image_dtype
(
input
,
tf
.
uint8
,
tf
.
uint8
,
...
@@ -68,7 +73,7 @@ def main(unused_args):
...
@@ -68,7 +73,7 @@ def main(unused_args):
input_files
.
append
(
FLAGS
.
input
)
input_files
.
append
(
FLAGS
.
input
)
image_shape
=
[
int
(
dim
)
for
dim
in
FLAGS
.
image_shape
.
split
(
','
)]
image_shape
=
[
int
(
dim
)
for
dim
in
FLAGS
.
image_shape
.
split
(
','
)]
tensors_to_images
(
input_files
,
image_shape
)
tensors_to_images
(
input_files
,
image_shape
,
FLAGS
.
add_softmax
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录