未验证 提交 487e6e60 编写于 作者: W wangguanzhong 提交者: GitHub

share train_batch_size from reader to model (#1532)

* share train_batch_size from reader to model

* update floor_divide

* add makedirs for dumping config
上级 31ce8917
......@@ -31,7 +31,6 @@ FPN:
spatial_scale: [0.03125, 0.0625, 0.125]
CornerHead:
train_batch_size: 14
test_batch_size: 1
ae_threshold: 0.5
num_dets: 100
......
......@@ -31,7 +31,6 @@ FPN:
spatial_scale: [0.03125, 0.0625, 0.125]
CornerHead:
train_batch_size: 14
test_batch_size: 1
ae_threshold: 0.5
num_dets: 100
......
......@@ -20,7 +20,6 @@ Hourglass:
modules: [2, 2, 2, 2, 4]
CornerHead:
train_batch_size: 14
test_batch_size: 1
ae_threshold: 0.5
num_dets: 100
......
......@@ -30,7 +30,6 @@ FPN:
spatial_scale: [0.03125, 0.0625, 0.125]
CornerHead:
train_batch_size: 14
test_batch_size: 1
ae_threshold: 0.5
num_dets: 100
......
......@@ -110,9 +110,10 @@ def nms(heat):
def _topk(scores, batch_size, height, width, K):
scores_r = fluid.layers.reshape(scores, [batch_size, -1])
topk_scores, topk_inds = fluid.layers.topk(scores_r, K)
topk_clses = topk_inds / (height * width)
topk_inds = fluid.layers.cast(topk_inds, 'int32')
topk_clses = topk_inds // (height * width)
topk_inds = topk_inds % (height * width)
topk_ys = fluid.layers.cast(topk_inds / width, 'float32')
topk_ys = fluid.layers.cast(topk_inds // width, 'float32')
topk_xs = fluid.layers.cast(topk_inds % width, 'float32')
return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
......@@ -231,10 +232,10 @@ class CornerHead(object):
num_dets(int): num of detections, 1000 by default
top_k(int): choose top_k pair of corners in prediction, 100 by default
"""
__shared__ = ['num_classes', 'stack']
__shared__ = ['num_classes', 'stack', 'train_batch_size']
def __init__(self,
train_batch_size,
train_batch_size=14,
test_batch_size=1,
num_classes=80,
stack=2,
......@@ -480,12 +481,10 @@ class CornerHead(object):
cornerpool_lib.right_pool,
is_test=True,
name='br_modules_' + str(ind))
tl_heat = self.pred_mod(
tl_modules, self.num_classes, name='tl_heats_' + str(ind))
br_heat = self.pred_mod(
br_modules, self.num_classes, name='br_heats_' + str(ind))
tl_tag = self.pred_mod(tl_modules, 1, name='tl_tags_' + str(ind))
br_tag = self.pred_mod(br_modules, 1, name='br_tags_' + str(ind))
......
......@@ -95,8 +95,11 @@ def parse_reader(reader_cfg, metric, arch):
def dump_infer_config(FLAGS, config):
arch_state = 0
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
from ppdet.core.config.yaml_helpers import setup_orderdict
setup_orderdict()
infer_cfg = OrderedDict({
......@@ -121,7 +124,13 @@ def dump_infer_config(FLAGS, config):
if arch in infer_arch:
infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size
arch_state = 1
break
if not arch_state:
logger.error(
'Architecture: {} is not supported for exporting model now'.format(
infer_arch))
os._exit(0)
if 'Mask' in config['architecture']:
infer_cfg['mask_resolution'] = config['MaskHead']['resolution']
......@@ -206,8 +215,8 @@ def main():
exe.run(startup_prog)
checkpoint.load_params(exe, infer_prog, cfg.weights)
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
dump_infer_config(FLAGS, cfg)
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
if __name__ == '__main__':
......
......@@ -180,12 +180,17 @@ def main():
logger.info('Infer iter {}'.format(iter_id))
if 'TTFNet' in cfg.architecture:
res['bbox'][1].append([len(res['bbox'][0])])
if 'CornerNet' in cfg.architecture:
from ppdet.utils.post_process import corner_post_process
post_config = getattr(cfg, 'PostProcess', None)
corner_post_process(res, post_config, cfg.num_classes)
bbox_results = None
mask_results = None
lmk_results = None
if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res:
mask_results = mask2out([res], clsid2catid,
model.mask_head.resolution)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册