Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
455639b2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
455639b2
编写于
1月 31, 2018
作者:
T
Tao Luo
提交者:
GitHub
1月 31, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7874 from Xreki/core_add_inference_unittest
Change the inference example to an unittest
上级
e1611eb4
f5990b46
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
189 addition
and
43 deletion
+189
-43
cmake/generic.cmake
cmake/generic.cmake
+10
-4
paddle/framework/executor.cc
paddle/framework/executor.cc
+0
-3
paddle/framework/feed_fetch_type.h
paddle/framework/feed_fetch_type.h
+4
-0
paddle/framework/program_desc.cc
paddle/framework/program_desc.cc
+1
-3
paddle/framework/program_desc.h
paddle/framework/program_desc.h
+1
-0
paddle/inference/CMakeLists.txt
paddle/inference/CMakeLists.txt
+2
-15
paddle/inference/io.cc
paddle/inference/io.cc
+5
-4
paddle/inference/io.h
paddle/inference/io.h
+0
-5
paddle/inference/tests/book/CMakeLists.txt
paddle/inference/tests/book/CMakeLists.txt
+7
-0
paddle/inference/tests/book/test_inference_recognize_digits.cc
...e/inference/tests/book/test_inference_recognize_digits.cc
+113
-0
paddle/testing/paddle_gtest_main.cc
paddle/testing/paddle_gtest_main.cc
+3
-1
python/paddle/v2/fluid/tests/book/test_recognize_digits.py
python/paddle/v2/fluid/tests/book/test_recognize_digits.py
+43
-8
未找到文件。
cmake/generic.cmake
浏览文件 @
455639b2
...
...
@@ -224,12 +224,18 @@ function(cc_test TARGET_NAME)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS DEPS
)
set
(
multiValueArgs SRCS DEPS
ARGS
)
cmake_parse_arguments
(
cc_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
add_executable
(
${
TARGET_NAME
}
${
cc_test_SRCS
}
)
target_link_libraries
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main paddle_memory gtest gflags
)
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main paddle_memory gtest gflags
)
if
(
"
${
cc_test_DEPS
}
"
MATCHES
"ARCHIVE_START"
)
list
(
REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END
)
endif
()
add_dependencies
(
${
TARGET_NAME
}
${
cc_test_DEPS
}
paddle_gtest_main paddle_memory gtest gflags
)
add_test
(
NAME
${
TARGET_NAME
}
COMMAND
${
TARGET_NAME
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
add_test
(
NAME
${
TARGET_NAME
}
COMMAND
${
TARGET_NAME
}
${
cc_test_ARGS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
)
endif
()
endfunction
(
cc_test
)
...
...
@@ -457,7 +463,7 @@ endfunction()
function
(
py_test TARGET_NAME
)
if
(
WITH_TESTING
)
set
(
options
STATIC static SHARED shared
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS DEPS ARGS
)
cmake_parse_arguments
(
py_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
...
...
paddle/framework/executor.cc
浏览文件 @
455639b2
...
...
@@ -33,9 +33,6 @@ DEFINE_bool(check_nan_inf, false,
namespace
paddle
{
namespace
framework
{
const
std
::
string
kFeedOpType
=
"feed"
;
const
std
::
string
kFetchOpType
=
"fetch"
;
Executor
::
Executor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
static
void
CreateTensor
(
Variable
*
var
,
proto
::
VarDesc
::
VarType
var_type
)
{
...
...
paddle/framework/feed_fetch_type.h
浏览文件 @
455639b2
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/framework/lod_tensor.h"
...
...
@@ -20,5 +21,8 @@ namespace paddle {
namespace
framework
{
using
FeedFetchType
=
LoDTensor
;
using
FeedFetchList
=
std
::
vector
<
FeedFetchType
>
;
static
const
std
::
string
kFeedOpType
=
"feed"
;
static
const
std
::
string
kFetchOpType
=
"fetch"
;
}
// namespace framework
}
// namespace paddle
paddle/framework/program_desc.cc
浏览文件 @
455639b2
...
...
@@ -14,13 +14,11 @@ limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/feed_fetch_type.h"
namespace
paddle
{
namespace
framework
{
const
std
::
string
kFeedOpType
=
"feed"
;
const
std
::
string
kFetchOpType
=
"fetch"
;
BlockDesc
*
ProgramDesc
::
AppendBlock
(
const
BlockDesc
&
parent
)
{
auto
*
b
=
desc_
.
add_blocks
();
b
->
set_parent_idx
(
parent
.
ID
());
...
...
paddle/framework/program_desc.h
浏览文件 @
455639b2
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/platform/macros.h"
...
...
paddle/inference/CMakeLists.txt
浏览文件 @
455639b2
...
...
@@ -24,19 +24,6 @@ if(NOT WITH_C_API AND WITH_FLUID)
install
(
TARGETS paddle_fluid_shared DESTINATION lib
)
endif
()
add_executable
(
example example.cc
)
if
(
APPLE
)
set
(
OPTIONAL_LINK_FLAGS
)
if
(
"
${
CMAKE_CXX_COMPILER_ID
}
"
STREQUAL
"Clang"
OR
"
${
CMAKE_CXX_COMPILER_ID
}
"
STREQUAL
"AppleClang"
)
set
(
OPTIONAL_LINK_FLAGS
"-undefined dynamic_lookup"
)
endif
()
target_link_libraries
(
example
-Wl,-force_load paddle_fluid
${
OPTIONAL_LINK_FLAGS
}
${
PTOOLS_LIB
}
)
else
()
target_link_libraries
(
example
-Wl,--start-group -Wl,--whole-archive paddle_fluid
-Wl,--no-whole-archive -Wl,--end-group
${
PTOOLS_LIB
}
)
if
(
WITH_TESTING
)
add_subdirectory
(
tests/book
)
endif
()
paddle/inference/io.cc
浏览文件 @
455639b2
...
...
@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/inference/io.h"
#include <fstream>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/feed_fetch_type.h"
namespace
paddle
{
namespace
inference
{
const
std
::
string
kFeedOpType
=
"feed"
;
bool
IsParameter
(
const
framework
::
VarDesc
*
var
,
const
framework
::
ProgramDesc
&
main_program
)
{
if
(
var
->
Persistable
())
{
...
...
@@ -27,7 +28,7 @@ bool IsParameter(const framework::VarDesc* var,
for
(
size_t
i
=
0
;
i
<
main_program
.
Size
();
++
i
)
{
const
framework
::
BlockDesc
&
block
=
main_program
.
Block
(
i
);
for
(
auto
*
op
:
block
.
AllOps
())
{
if
(
op
->
Type
()
==
kFeedOpType
)
{
if
(
op
->
Type
()
==
framework
::
kFeedOpType
)
{
continue
;
}
for
(
auto
input_argument_name
:
op
->
InputArgumentNames
())
{
...
...
@@ -51,7 +52,7 @@ void LoadPersistables(framework::Executor& executor,
framework
::
BlockDesc
*
load_block
=
load_program
->
MutableBlock
(
0
);
for
(
auto
*
var
:
global_block
.
AllVars
())
{
if
(
IsParameter
(
var
,
main_program
))
{
LOG
(
INFO
)
<<
"parameter's name: "
<<
var
->
Name
();
VLOG
(
3
)
<<
"parameter's name: "
<<
var
->
Name
();
framework
::
VarDesc
*
new_var
=
load_block
->
Var
(
var
->
Name
());
new_var
->
SetShape
(
var
->
Shape
());
...
...
paddle/inference/io.h
浏览文件 @
455639b2
...
...
@@ -17,18 +17,13 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/var_desc.h"
namespace
paddle
{
namespace
inference
{
bool
IsParameter
(
const
framework
::
VarDesc
*
var
,
const
framework
::
ProgramDesc
&
main_program
);
void
LoadPersistables
(
framework
::
Executor
&
executor
,
framework
::
Scope
&
scope
,
const
std
::
string
&
dirname
,
...
...
paddle/inference/tests/book/CMakeLists.txt
0 → 100644
浏览文件 @
455639b2
set
(
PYTHON_TESTS_DIR
${
PADDLE_SOURCE_DIR
}
/python/paddle/v2/fluid/tests
)
cc_test
(
test_inference_recognize_digits_mlp
SRCS test_inference_recognize_digits.cc
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/recognize_digits_mlp.inference.model
)
set_tests_properties
(
test_inference_recognize_digits_mlp
PROPERTIES DEPENDS test_recognize_digits_mlp_cpu
)
paddle/inference/
example
.cc
→
paddle/inference/
tests/book/test_inference_recognize_digits
.cc
浏览文件 @
455639b2
...
...
@@ -12,93 +12,102 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <
io
stream>
#include <
s
stream>
#include "gflags/gflags.h"
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_dirname
.
empty
())
{
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
std
::
cout
<<
"Usage: ./example --dirname=path/to/your/model"
<<
std
::
endl
;
exit
(
1
);
}
// 1. Define place, executor, scope
auto
place
=
paddle
::
platform
::
CPUPlace
();
paddle
::
framework
::
InitDevices
();
auto
*
executor
=
new
paddle
::
framework
::
Executor
(
place
);
template
<
typename
Place
,
typename
T
>
void
TestInference
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_feeds
,
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_fetchs
)
{
// 1. Define place, executor and scope
auto
place
=
Place
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
auto
*
scope
=
new
paddle
::
framework
::
Scope
();
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
string
dirname
=
FLAGS_dirname
;
// 2. Initialize the inference program
auto
inference_program
=
paddle
::
inference
::
Load
(
*
executor
,
*
scope
,
dirname
);
// 2. Initialize the inference_program and load all parameters from file
auto
inference_program
=
paddle
::
inference
::
Load
(
executor
,
*
scope
,
dirname
);
// 3. Optional: perform optimization on the inference_program
// 4. Get the feed_target_names and fetch_target_names
// 3. Get the feed_target_names and fetch_target_names
const
std
::
vector
<
std
::
string
>&
feed_target_names
=
inference_program
->
GetFeedTargetNames
();
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
inference_program
->
GetFetchTargetNames
();
// 5. Generate input
paddle
::
framework
::
LoDTensor
input
;
srand
(
time
(
0
));
float
*
input_ptr
=
input
.
mutable_data
<
float
>
({
1
,
784
},
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
784
;
++
i
)
{
input_ptr
[
i
]
=
rand
()
/
(
static_cast
<
float
>
(
RAND_MAX
));
}
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
feeds
;
feeds
.
push_back
(
input
);
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
fetchs
;
// Set up maps for feed and fetch targets
// 4. Prepare inputs: set up maps for feed targets
std
::
map
<
std
::
string
,
const
paddle
::
framework
::
LoDTensor
*>
feed_targets
;
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
// set_feed_variable
for
(
size_t
i
=
0
;
i
<
feed_target_names
.
size
();
++
i
)
{
feed_targets
[
feed_target_names
[
i
]]
=
&
feeds
[
i
];
// Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets
[
feed_target_names
[
i
]]
=
cpu_feeds
[
i
];
}
//
get_fetch_variable
fetchs
.
resize
(
fetch_target_names
.
size
())
;
//
5. Define Tensor to get the outputs: set up maps for fetch targets
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
for
(
size_t
i
=
0
;
i
<
fetch_target_names
.
size
();
++
i
)
{
fetch_targets
[
fetch_target_names
[
i
]]
=
&
fetchs
[
i
];
fetch_targets
[
fetch_target_names
[
i
]]
=
cpu_
fetchs
[
i
];
}
// Run the inference program
executor
->
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
//
6.
Run the inference program
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
// Get outputs
for
(
size_t
i
=
0
;
i
<
fetchs
.
size
();
++
i
)
{
auto
dims_i
=
fetchs
[
i
].
dims
();
std
::
cout
<<
"dims_i:"
;
for
(
int
j
=
0
;
j
<
dims_i
.
size
();
++
j
)
{
std
::
cout
<<
" "
<<
dims_i
[
j
];
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"result:"
;
float
*
output_ptr
=
fetchs
[
i
].
data
<
float
>
();
for
(
int
j
=
0
;
j
<
paddle
::
framework
::
product
(
dims_i
);
++
j
)
{
std
::
cout
<<
" "
<<
output_ptr
[
j
];
}
std
::
cout
<<
std
::
endl
;
delete
scope
;
}
TEST
(
inference
,
recognize_digits
)
{
if
(
FLAGS_dirname
.
empty
())
{
LOG
(
FATAL
)
<<
"Usage: ./example --dirname=path/to/your/model"
;
}
delete
scope
;
delete
executor
;
LOG
(
INFO
)
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
string
dirname
=
FLAGS_dirname
;
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
return
0
;
paddle
::
framework
::
LoDTensor
input
;
srand
(
time
(
0
));
float
*
input_ptr
=
input
.
mutable_data
<
float
>
({
1
,
28
,
28
},
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
784
;
++
i
)
{
input_ptr
[
i
]
=
rand
()
/
(
static_cast
<
float
>
(
RAND_MAX
));
}
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>
cpu_feeds
;
cpu_feeds
.
push_back
(
&
input
);
paddle
::
framework
::
LoDTensor
output1
;
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>
cpu_fetchs1
;
cpu_fetchs1
.
push_back
(
&
output1
);
// Run inference on CPU
TestInference
<
paddle
::
platform
::
CPUPlace
,
float
>
(
dirname
,
cpu_feeds
,
cpu_fetchs1
);
LOG
(
INFO
)
<<
output1
.
dims
();
#ifdef PADDLE_WITH_CUDA
paddle
::
framework
::
LoDTensor
output2
;
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>
cpu_fetchs2
;
cpu_fetchs2
.
push_back
(
&
output2
);
// Run inference on CUDA GPU
TestInference
<
paddle
::
platform
::
CUDAPlace
,
float
>
(
dirname
,
cpu_feeds
,
cpu_fetchs2
);
LOG
(
INFO
)
<<
output2
.
dims
();
EXPECT_EQ
(
output1
.
dims
(),
output2
.
dims
());
EXPECT_EQ
(
output1
.
numel
(),
output2
.
numel
());
float
err
=
1E-3
;
int
count
=
0
;
for
(
int64_t
i
=
0
;
i
<
output1
.
numel
();
++
i
)
{
if
(
fabs
(
output1
.
data
<
float
>
()[
i
]
-
output2
.
data
<
float
>
()[
i
])
>
err
)
{
count
++
;
}
}
EXPECT_EQ
(
count
,
0
)
<<
"There are "
<<
count
<<
" different elements."
;
#endif
}
paddle/testing/paddle_gtest_main.cc
浏览文件 @
455639b2
...
...
@@ -22,7 +22,9 @@ limitations under the License. */
int
main
(
int
argc
,
char
**
argv
)
{
std
::
vector
<
char
*>
new_argv
;
std
::
string
gflags_env
;
new_argv
.
push_back
(
argv
[
0
]);
for
(
int
i
=
0
;
i
<
argc
;
++
i
)
{
new_argv
.
push_back
(
argv
[
i
]);
}
#ifdef PADDLE_WITH_CUDA
new_argv
.
push_back
(
strdup
(
"--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"
));
...
...
python/paddle/v2/fluid/tests/book/test_recognize_digits.py
浏览文件 @
455639b2
...
...
@@ -45,8 +45,9 @@ BATCH_SIZE = 64
def
loss_net
(
hidden
,
label
):
prediction
=
fluid
.
layers
.
fc
(
input
=
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
return
fluid
.
layers
.
mean
(
x
=
loss
),
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
)
avg_loss
=
fluid
.
layers
.
mean
(
x
=
loss
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
)
return
prediction
,
avg_loss
,
acc
def
mlp
(
img
,
label
):
...
...
@@ -73,8 +74,7 @@ def conv_net(img, label):
return
loss_net
(
conv_pool_2
,
label
)
def
main
():
args
=
parse_arg
()
def
train
(
args
,
save_dirname
=
None
):
print
(
"recognize digits with args: {0}"
.
format
(
" "
.
join
(
sys
.
argv
[
1
:])))
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
...
...
@@ -91,7 +91,8 @@ def main():
with
pd
.
do
():
img_
=
pd
.
read_input
(
img
)
label_
=
pd
.
read_input
(
label
)
for
o
in
net_conf
(
img_
,
label_
):
prediction
,
avg_loss
,
acc
=
net_conf
(
img_
,
label_
)
for
o
in
[
avg_loss
,
acc
]:
pd
.
write_output
(
o
)
avg_loss
,
acc
=
pd
()
...
...
@@ -99,7 +100,7 @@ def main():
avg_loss
=
fluid
.
layers
.
mean
(
x
=
avg_loss
)
acc
=
fluid
.
layers
.
mean
(
x
=
acc
)
else
:
avg_loss
,
acc
=
net_conf
(
img
,
label
)
prediction
,
avg_loss
,
acc
=
net_conf
(
img
,
label
)
test_program
=
fluid
.
default_main_program
().
clone
()
...
...
@@ -137,7 +138,10 @@ def main():
acc_val
=
numpy
.
array
(
acc_set
).
mean
()
avg_loss_val
=
numpy
.
array
(
avg_loss_set
).
mean
()
if
float
(
acc_val
)
>
0.85
:
# test acc > 85%
exit
(
0
)
if
save_dirname
is
not
None
:
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
"img"
],
[
prediction
],
exe
)
return
else
:
print
(
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'
.
...
...
@@ -145,5 +149,36 @@ def main():
float
(
avg_loss_val
),
float
(
acc_val
)))
def
infer
(
args
,
save_dirname
=
None
):
if
save_dirname
is
None
:
return
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
fluid
.
io
.
load_inference_model
(
save_dirname
,
exe
)
# The input's dimension of conv should be 4-D or 5-D.
tensor_img
=
numpy
.
random
.
rand
(
1
,
1
,
28
,
28
).
astype
(
"float32"
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
print
(
"infer results: "
,
results
[
0
])
if
__name__
==
'__main__'
:
main
()
args
=
parse_arg
()
if
not
args
.
use_cuda
and
not
args
.
parallel
:
save_dirname
=
"recognize_digits_"
+
args
.
nn_type
+
".inference.model"
else
:
save_dirname
=
None
train
(
args
,
save_dirname
)
infer
(
args
,
save_dirname
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录