提交 c6c66954 编写于 作者: T tink2123 提交者: dengkaipeng

add multiprocess

上级 7fa5296a
......@@ -115,7 +115,7 @@ class DataSetReader(object):
image_ids = self.COCO.getImgIds()
image_ids.sort()
imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
# imgs = imgs[:8]
imgs = imgs[-8:]
for img in imgs:
img['image'] = os.path.join(self.img_dir, img['file_name'])
assert os.path.exists(img['image']), \
......@@ -247,9 +247,13 @@ def train(size=416,
interval=10,
pyreader_num=1,
num_workers=16,
max_queue=32):
max_queue=32,
use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, random_shape_iter, random_sizes)
if not use_multiprocessing:
return generator
def infinite_reader():
while True:
for data in generator():
......
......@@ -96,11 +96,11 @@ def train():
random_shape_iter = cfg.max_iter - cfg.start_iter - cfg.tune_iter
if cfg.use_pyreader:
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch'])/devices_num, shuffle=True, random_shape_iter=random_shape_iter, random_sizes=random_sizes, interval=10, pyreader_num=devices_num)
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch'])/devices_num, shuffle=True, random_shape_iter=random_shape_iter, random_sizes=random_sizes, interval=10, pyreader_num=devices_num,use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
else:
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']), shuffle=True, random_shape_iter=random_shape_iter, random_sizes=random_sizes, interval=10)
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']), shuffle=True, random_shape_iter=random_shape_iter, random_sizes=random_sizes, interval=10,use_multiprocessing=cfg.use_multiprocess)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix):
......
......@@ -106,6 +106,7 @@ def parse_args():
add_arg('use_pyreader', bool, True, "Use pyreader.")
add_arg('use_profile', bool, False, "Whether use profiler.")
add_arg('start_iter', int, 0, "Start iteration.")
add_arg('use_multiprocess', bool, True, "add multiprocess.")
#SOLVER
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Iter number.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册