提交 e6c2cb83 编写于 作者: Z zhengya01

add ce

上级 3d8f79d6
...@@ -7,11 +7,11 @@ export OMP_NUM_THREADS=1 ...@@ -7,11 +7,11 @@ export OMP_NUM_THREADS=1
cudaid=${face_detection:=0} # use 0-th card as default cudaid=${face_detection:=0} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train.py --model_save_dir=output/ --data_dir=dataset/coco/ --max_iter=20 --enable_ce | python _ce.py FLAGS_benchmark=true python train.py --model_save_dir=output/ --data_dir=dataset/coco/ --max_iter=20 --enable_ce --pretrained_model=./imagenet_resnet50_fusebn | python _ce.py
cudaid=${face_detection_m:=0,1,2,3} # use 0,1,2,3 card as default cudaid=${face_detection_m:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train.py --model_save_dir=output/ --data_dir=dataset/coco/ --max_iter=20 --enable_ce | python _ce.py FLAGS_benchmark=true python train.py --model_save_dir=output/ --data_dir=dataset/coco/ --max_iter=20 --enable_ce --pretrained_model=./imagenet_resnet50_fusebn | python _ce.py
...@@ -46,11 +46,14 @@ def train(): ...@@ -46,11 +46,14 @@ def train():
devices_num = len(devices.split(",")) devices_num = len(devices.split(","))
total_batch_size = devices_num * cfg.TRAIN.im_per_batch total_batch_size = devices_num * cfg.TRAIN.im_per_batch
use_random = True
if cfg.enable_ce:
use_random = False
model = model_builder.FasterRCNN( model = model_builder.FasterRCNN(
add_conv_body_func=resnet.add_ResNet50_conv4_body, add_conv_body_func=resnet.add_ResNet50_conv4_body,
add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head, add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
use_pyreader=cfg.use_pyreader, use_pyreader=cfg.use_pyreader,
use_random=True) use_random=use_random)
model.build_model(image_shape) model.build_model(image_shape)
loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss() loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss()
loss_cls.persistable = True loss_cls.persistable = True
...@@ -92,16 +95,19 @@ def train(): ...@@ -92,16 +95,19 @@ def train():
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=bool(cfg.use_gpu), loss_name=loss.name) use_cuda=bool(cfg.use_gpu), loss_name=loss.name)
shuffle = True
if cfg.enable_ce:
shuffle = False
if cfg.use_pyreader: if cfg.use_pyreader:
train_reader = reader.train( train_reader = reader.train(
batch_size=cfg.TRAIN.im_per_batch, batch_size=cfg.TRAIN.im_per_batch,
total_batch_size=total_batch_size, total_batch_size=total_batch_size,
padding_total=cfg.TRAIN.padding_minibatch, padding_total=cfg.TRAIN.padding_minibatch,
shuffle=True) shuffle=shuffle)
py_reader = model.py_reader py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader) py_reader.decorate_paddle_reader(train_reader)
else: else:
train_reader = reader.train(batch_size=total_batch_size, shuffle=True) train_reader = reader.train(batch_size=total_batch_size, shuffle=shuffle)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix): def save_model(postfix):
...@@ -142,7 +148,7 @@ def train(): ...@@ -142,7 +148,7 @@ def train():
save_model("model_iter{}".format(iter_id)) save_model("model_iter{}".format(iter_id))
# only for ce # only for ce
if cfg.enable_ce: if cfg.enable_ce:
gpu_num = get_cards(cfg) gpu_num = devices_num
epoch_idx = iter_id + 1 epoch_idx = iter_id + 1
loss = last_loss loss = last_loss
print("kpis\teach_pass_duration_card%s\t%s" % print("kpis\teach_pass_duration_card%s\t%s" %
...@@ -185,7 +191,7 @@ def train(): ...@@ -185,7 +191,7 @@ def train():
break break
# only for ce # only for ce
if cfg.enable_ce: if cfg.enable_ce:
gpu_num = get_cards(cfg) gpu_num = devices_num
epoch_idx = iter_id + 1 epoch_idx = iter_id + 1
loss = last_loss loss = last_loss
print("kpis\teach_pass_duration_card%s\t%s" % print("kpis\teach_pass_duration_card%s\t%s" %
...@@ -202,15 +208,6 @@ def train(): ...@@ -202,15 +208,6 @@ def train():
save_model('model_final') save_model('model_final')
def get_cards(cfg):
if cfg.enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
else:
return cfg.num_devices
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
print_arguments(args) print_arguments(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册