未验证 提交 31e2fd99 编写于 作者: B Bai Yifan 提交者: GitHub

Some fixes for CE (continuous evaluation) (#1242)

* CE fix
上级 d51ebfee
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# This file is only used for continuous evaluation. # This file is only used for continuous evaluation.
export FLAGS_cudnn_deterministic=True export FLAGS_cudnn_deterministic=True
BATCH_SIZE=56
cudaid=${object_detection_cudaid:=0} cudaid=${object_detection_cudaid:=0}
export CUDA_VISIBLE_DEVICES=$cudaid export CUDA_VISIBLE_DEVICES=$cudaid
python train.py --batch_size=64 --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py python train.py --batch_size=${BATCH_SIZE} --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py
cudaid=${object_detection_cudaid_m:=0, 1, 2, 3} cudaid=${object_detection_cudaid_m:=0, 1, 2, 3}
export CUDA_VISIBLE_DEVICES=$cudaid export CUDA_VISIBLE_DEVICES=$cudaid
python train.py --batch_size=64 --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py python train.py --batch_size=${BATCH_SIZE} --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py
...@@ -292,24 +292,24 @@ def train(settings, ...@@ -292,24 +292,24 @@ def train(settings,
shuffle=True, shuffle=True,
use_multiprocessing=True, use_multiprocessing=True,
num_workers=8, num_workers=8,
max_queue=24): max_queue=24,
enable_ce=False):
file_list = os.path.join(settings.data_dir, file_list) file_list = os.path.join(settings.data_dir, file_list)
def infinite_reader(gen):
while True:
for data in gen():
yield data
if 'coco' in settings.dataset: if 'coco' in settings.dataset:
generator = coco(settings, file_list, "train", batch_size, shuffle) generator = coco(settings, file_list, "train", batch_size, shuffle)
else: else:
generator = pascalvoc(settings, file_list, "train", batch_size, shuffle) generator = pascalvoc(settings, file_list, "train", batch_size, shuffle)
def infinite_reader():
while True:
for data in generator():
yield data
def reader(): def reader():
try: try:
enqueuer = GeneratorEnqueuer( enqueuer = GeneratorEnqueuer(
infinite_reader(generator), infinite_reader(), use_multiprocessing=use_multiprocessing)
use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers) enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_output = None generator_output = None
while True: while True:
...@@ -325,7 +325,10 @@ def train(settings, ...@@ -325,7 +325,10 @@ def train(settings,
if enqueuer is not None: if enqueuer is not None:
enqueuer.stop() enqueuer.stop()
return reader if enable_ce:
return infinite_reader
else:
return reader
def test(settings, file_list, batch_size): def test(settings, file_list, batch_size):
......
...@@ -32,7 +32,7 @@ add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model") ...@@ -32,7 +32,7 @@ add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model")
train_parameters = { train_parameters = {
"pascalvoc": { "pascalvoc": {
"train_images": 19200, "train_images": 16551,
"image_shape": [3, 300, 300], "image_shape": [3, 300, 300],
"class_num": 21, "class_num": 21,
"batch_size": 64, "batch_size": 64,
...@@ -143,7 +143,6 @@ def train(args, ...@@ -143,7 +143,6 @@ def train(args,
startup_prog.random_seed = 111 startup_prog.random_seed = 111
train_prog.random_seed = 111 train_prog.random_seed = 111
test_prog.random_seed = 111 test_prog.random_seed = 111
num_workers = 1
train_py_reader, loss = build_program( train_py_reader, loss = build_program(
main_prog=train_prog, main_prog=train_prog,
...@@ -170,14 +169,14 @@ def train(args, ...@@ -170,14 +169,14 @@ def train(args,
if parallel: if parallel:
train_exe = fluid.ParallelExecutor(main_program=train_prog, train_exe = fluid.ParallelExecutor(main_program=train_prog,
use_cuda=use_gpu, loss_name=loss.name) use_cuda=use_gpu, loss_name=loss.name)
train_reader = reader.train(data_args, train_reader = reader.train(data_args,
train_file_list, train_file_list,
batch_size_per_device, batch_size_per_device,
shuffle=is_shuffle, shuffle=is_shuffle,
use_multiprocessing=True, use_multiprocessing=True,
num_workers=num_workers, num_workers=num_workers,
max_queue=24) max_queue=24,
enable_ce=enable_ce)
test_reader = reader.test(data_args, val_file_list, batch_size) test_reader = reader.test(data_args, val_file_list, batch_size)
train_py_reader.decorate_paddle_reader(train_reader) train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader) test_py_reader.decorate_paddle_reader(test_reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册