提交 2bb0ac92 编写于 作者: M minqiyang

Polish code

上级 a04d9981
...@@ -107,8 +107,7 @@ def train(use_cuda, train_program, parallel, params_dirname): ...@@ -107,8 +107,7 @@ def train(use_cuda, train_program, parallel, params_dirname):
event_handler=event_handler, event_handler=event_handler,
feed_order=['pixel', 'label']) feed_order=['pixel', 'label'])
if six.PY3: return trainer
del trainer
def infer(use_cuda, inference_program, parallel, params_dirname=None): def infer(use_cuda, inference_program, parallel, params_dirname=None):
...@@ -132,12 +131,15 @@ def main(use_cuda, parallel): ...@@ -132,12 +131,15 @@ def main(use_cuda, parallel):
save_path = "image_classification_vgg.inference.model" save_path = "image_classification_vgg.inference.model"
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
train( trainer = train(
use_cuda=use_cuda, use_cuda=use_cuda,
train_program=train_network, train_program=train_network,
params_dirname=save_path, params_dirname=save_path,
parallel=parallel) parallel=parallel)
if six.PY3:
del trainer
# FIXME(zcd): in the inference stage, the number of # FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel. # input data is one, it is not appropriate to use parallel.
if parallel and use_cuda: if parallel and use_cuda:
......
...@@ -90,8 +90,7 @@ def train(use_cuda, train_program, params_dirname, parallel): ...@@ -90,8 +90,7 @@ def train(use_cuda, train_program, params_dirname, parallel):
reader=train_reader, reader=train_reader,
feed_order=['img', 'label']) feed_order=['img', 'label'])
if six.PY3: return trainer
del trainer
def infer(use_cuda, inference_program, parallel, params_dirname=None): def infer(use_cuda, inference_program, parallel, params_dirname=None):
...@@ -117,12 +116,15 @@ def main(use_cuda, parallel): ...@@ -117,12 +116,15 @@ def main(use_cuda, parallel):
# call train() with is_local argument to run distributed train # call train() with is_local argument to run distributed train
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
train( trainer = train(
use_cuda=use_cuda, use_cuda=use_cuda,
train_program=train_program, train_program=train_program,
params_dirname=params_dirname, params_dirname=params_dirname,
parallel=parallel) parallel=parallel)
if six.PY3:
del trainer
# FIXME(zcd): in the inference stage, the number of # FIXME(zcd): in the inference stage, the number of
# input data is one, it is not appropriate to use parallel. # input data is one, it is not appropriate to use parallel.
if parallel and use_cuda: if parallel and use_cuda:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册