未验证 提交 4f457075 编写于 作者: Y Yang Zhang 提交者: GitHub

Tweak command line scripts (#2517)

* Tweak command line scripts

enable fine grained control of command flags parsing
switch to attribute style of accessing config options

* Break down visualization function

decouple it from IO operations

* Move flag parsing out of `main()`

* Fix a bug where `None` is returned instead of `{}`

* Rename `save_xxx` to `output_xxx` in command line flags

could be confusing since checkpoint is stored in `save_dir`

* Support image file extensions in upper case
上级 60a0e779
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys from argparse import ArgumentParser, RawDescriptionHelpFormatter
from argparse import ArgumentParser, RawDescriptionHelpFormatter, REMAINDER
import yaml import yaml
__all__ = ['ColorTTY', 'ArgsParser']
class ColorTTY(object): class ColorTTY(object):
def __init__(self): def __init__(self):
...@@ -40,60 +41,39 @@ class ColorTTY(object): ...@@ -40,60 +41,39 @@ class ColorTTY(object):
return "[{}m{}".format(code, message) return "[{}m{}".format(code, message)
def parse_args(): class ArgsParser(ArgumentParser):
parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter)
parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument(
"-s",
"--savefile",
default=None,
type=str,
help="Save json file name for evaluation, if not set, default files are bbox.json and mask.json."
)
parser.add_argument(
"-r",
"--resume_checkpoint",
default=None,
type=str,
help="The checkpoint path for resuming training.")
parser.add_argument(
"--eval",
action='store_true',
default=False,
help="Whether perform evaluation in train")
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Image directory path to perform inference.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path to perform inference, --infer-img has a higher priority than --image-dir")
parser.add_argument(
"-o", "--opt", nargs=REMAINDER, help="set configuration options")
args = parser.parse_args()
if args.config is None: def __init__(self):
raise ValueError("Please specify --config=configure_file_path.") super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument("-o", "--opt", nargs='*',
help="set configuration options")
cli_config = {} def parse_args(self, argv=None):
if 'opt' in vars(args) and args.opt is not None: args = super(ArgsParser, self).parse_args(argv)
for s in args.opt: assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
def _parse_opt(self, opts):
config = {}
if not opts:
return config
for s in opts:
s = s.strip() s = s.strip()
k, v = s.split('=') k, v = s.split('=')
if '.' not in k: if '.' not in k:
cli_config[k] = v config[k] = v
else: else:
keys = k.split('.') keys = k.split('.')
cli_config[keys[0]] = {} config[keys[0]] = {}
cur = cli_config[keys[0]] cur = config[keys[0]]
for idx, key in enumerate(keys[1:]): for idx, key in enumerate(keys[1:]):
if idx == len(keys) - 2: if idx == len(keys) - 2:
cur[key] = yaml.load(v, Loader=yaml.Loader) cur[key] = yaml.load(v, Loader=yaml.Loader)
else: else:
cur[key] = {} cur[key] = {}
cur = cur[key] cur = cur[key]
args.cli_config = cli_config return config
return args
...@@ -48,7 +48,7 @@ def parse_fetches(fetches, prog=None, extra_keys=None): ...@@ -48,7 +48,7 @@ def parse_fetches(fetches, prog=None, extra_keys=None):
v.persistable = True v.persistable = True
keys.append(k) keys.append(k)
values.append(v.name) values.append(v.name)
except: except Exception:
pass pass
return keys, values, cls return keys, values, cls
...@@ -88,23 +88,21 @@ def eval_run(exe, compile_program, pyreader, keys, values, cls): ...@@ -88,23 +88,21 @@ def eval_run(exe, compile_program, pyreader, keys, values, cls):
return results return results
def eval_results(results, feed, args, cfg): def eval_results(results, feed, metric, resolution, output_file=None):
"""Evaluation for evaluation program results""" """Evaluation for evaluation program results"""
metric = cfg['metric']
if metric == 'COCO': if metric == 'COCO':
from ppdet.utils.coco_eval import bbox_eval, mask_eval from ppdet.utils.coco_eval import bbox_eval, mask_eval
anno_file = getattr(feed.dataset, 'annotation', None) anno_file = getattr(feed.dataset, 'annotation', None)
with_background = getattr(feed, 'with_background', True) with_background = getattr(feed, 'with_background', True)
savefile = 'bbox.json' output = 'bbox.json'
if args.savefile: if output_file:
savefile = '{}_bbox.json'.format(args.savefile) output = '{}_bbox.json'.format(output_file)
bbox_eval(results, anno_file, savefile, with_background) bbox_eval(results, anno_file, output, with_background)
if 'mask' in results[0]: if 'mask' in results[0]:
savefile = 'mask.json' output = 'mask.json'
if args.savefile: if output_file:
savefile = '{}_mask.json'.format(args.savefile) output = '{}_mask.json'.format(output_file)
mask_eval(results, anno_file, savefile, mask_eval(results, anno_file, output, resolution)
cfg['MaskHead']['resolution'])
else: else:
res = np.mean(results[-1]['accum_map'][0]) res = np.mean(results[-1]['accum_map'][0])
logger.info('Test mAP: {}'.format(res)) logger.info('Test mAP: {}'.format(res))
...@@ -17,22 +17,16 @@ from __future__ import division ...@@ -17,22 +17,16 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import os
import logging
import numpy as np import numpy as np
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from .colormap import colormap from .colormap import colormap
logger = logging.getLogger(__name__)
__all__ = ['visualize_results'] __all__ = ['visualize_results']
SAVE_HOME = 'output'
def visualize_results(image_path, def visualize_results(image,
catid2name, catid2name,
threshold=0.5, threshold=0.5,
bbox_results=None, bbox_results=None,
...@@ -40,19 +34,11 @@ def visualize_results(image_path, ...@@ -40,19 +34,11 @@ def visualize_results(image_path,
""" """
Visualize bbox and mask results Visualize bbox and mask results
""" """
if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
logger.info("Image {} detect: ".format(image_path))
image = Image.open(image_path)
if mask_results: if mask_results:
image = draw_mask(image, mask_results, threshold) image = draw_mask(image, mask_results, threshold)
if bbox_results: if bbox_results:
image = draw_bbox(image, catid2name, bbox_results, threshold) image = draw_bbox(image, catid2name, bbox_results, threshold)
return image
save_name = get_save_image_name(image_path)
logger.info("Detection results save in {}\n".format(save_name))
image.save(save_name)
def draw_mask(image, segms, threshold, alpha=0.7): def draw_mask(image, segms, threshold, alpha=0.7):
...@@ -62,7 +48,7 @@ def draw_mask(image, segms, threshold, alpha=0.7): ...@@ -62,7 +48,7 @@ def draw_mask(image, segms, threshold, alpha=0.7):
im_width, im_height = image.size im_width, im_height = image.size
mask_color_id = 0 mask_color_id = 0
w_ratio = .4 w_ratio = .4
image = np.array(image).astype('float32') img_array = np.array(image).astype('float32')
for dt in np.array(segms): for dt in np.array(segms):
segm, score = dt['segmentation'], dt['score'] segm, score = dt['segmentation'], dt['score']
if score < threshold: if score < threshold:
...@@ -74,10 +60,9 @@ def draw_mask(image, segms, threshold, alpha=0.7): ...@@ -74,10 +60,9 @@ def draw_mask(image, segms, threshold, alpha=0.7):
for c in range(3): for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask) idx = np.nonzero(mask)
image[idx[0], idx[1], :] *= 1.0 - alpha img_array[idx[0], idx[1], :] *= 1.0 - alpha
image[idx[0], idx[1], :] += alpha * color_mask img_array[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(image.astype('uint8')) return Image.fromarray(img_array.astype('uint8'))
return image
def draw_bbox(image, catid2name, bboxes, threshold): def draw_bbox(image, catid2name, bboxes, threshold):
...@@ -101,17 +86,5 @@ def draw_bbox(image, catid2name, bboxes, threshold): ...@@ -101,17 +86,5 @@ def draw_bbox(image, catid2name, bboxes, threshold):
fill='red') fill='red')
if image.mode == 'RGB': if image.mode == 'RGB':
draw.text((xmin, ymin), catid2name[catid], (255, 255, 0)) draw.text((xmin, ymin), catid2name[catid], (255, 255, 0))
logger.info("\t {:15s} at {:25} score: {:.5f}".format(
catid2name[catid],
str(list(map(int, [xmin, ymin, xmax, ymax]))),
score))
return image return image
def get_save_image_name(image_path):
"""
Get save image name from source image path.
"""
image_name = image_path.split('/')[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(SAVE_HOME, "{}".format(name)) + ext
...@@ -23,7 +23,7 @@ import paddle.fluid as fluid ...@@ -23,7 +23,7 @@ import paddle.fluid as fluid
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import parse_args from ppdet.utils.cli import ArgsParser
from ppdet.modeling.model_input import create_feeds from ppdet.modeling.model_input import create_feeds
from ppdet.data.data_feed import create_reader from ppdet.data.data_feed import create_reader
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
...@@ -38,28 +38,27 @@ def main(): ...@@ -38,28 +38,27 @@ def main():
""" """
Main evaluate function Main evaluate function
""" """
args = parse_args() cfg = load_config(FLAGS.config)
cfg = load_config(args.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg['architecture'] main_arch = cfg.architecture
else: else:
raise ValueError("'architecture' not specified in config file.") raise ValueError("'architecture' not specified in config file.")
merge_config(args.cli_config) merge_config(FLAGS.opt)
if cfg['use_gpu']: if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count() devices_num = fluid.core.get_cuda_device_count()
else: else:
devices_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) devices_num = int(os.environ.get('CPU_NUM',
multiprocessing.cpu_count()))
if 'eval_feed' not in cfg: if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed') eval_feed = create(main_arch + 'EvalFeed')
else: else:
eval_feed = create(cfg['eval_feed']) eval_feed = create(cfg.eval_feed)
# define executor # define executor
place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# 2. build program # 2. build program
...@@ -88,11 +87,11 @@ def main(): ...@@ -88,11 +87,11 @@ def main():
# 5. Load model # 5. Load model
exe.run(startup_prog) exe.run(startup_prog)
if cfg['weights']: if 'weights' in cfg:
checkpoint.load_pretrain(exe, eval_prog, cfg['weights']) checkpoint.load_pretrain(exe, eval_prog, cfg.weights)
extra_keys = [] extra_keys = []
if cfg['metric'] == 'COCO': if 'metric' in cfg and cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape'] extra_keys = ['im_info', 'im_id', 'im_shape']
keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys) keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys)
...@@ -100,8 +99,18 @@ def main(): ...@@ -100,8 +99,18 @@ def main():
# 6. Run # 6. Run
results = eval_run(exe, compile_program, pyreader, keys, values, cls) results = eval_run(exe, compile_program, pyreader, keys, values, cls)
# Evaluation # Evaluation
eval_results(results, eval_feed, args, cfg) eval_results(results, eval_feed, cfg.metric,
cfg.MaskHead.resolution, FLAGS.output_file)
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-f",
"--output_file",
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json."
)
FLAGS = parser.parse_args()
main() main()
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import glob import glob
import numpy as np import numpy as np
from PIL import Image
from paddle import fluid from paddle import fluid
...@@ -28,7 +29,7 @@ from ppdet.modeling.model_input import create_feeds ...@@ -28,7 +29,7 @@ from ppdet.modeling.model_input import create_feeds
from ppdet.data.data_feed import create_reader from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import parse_args from ppdet.utils.cli import ArgsParser
from ppdet.utils.visualizer import visualize_results from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
...@@ -38,6 +39,17 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -38,6 +39,17 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = image_path.split('/')[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def get_test_images(infer_dir, infer_img): def get_test_images(infer_dir, infer_img):
""" """
Get image path list in TEST mode Get image path list in TEST mode
...@@ -54,36 +66,36 @@ def get_test_images(infer_dir, infer_img): ...@@ -54,36 +66,36 @@ def get_test_images(infer_dir, infer_img):
infer_dir = os.path.abspath(infer_dir) infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \ assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir) "infer_dir {} is not a directory".format(infer_dir)
for fmt in ['jpg', 'jpeg', 'png', 'bmp']: exts = ['jpg', 'jpeg', 'png', 'bmp']
images.extend(glob.glob('{}/*.{}'.format(infer_dir, fmt))) exts += [ext.upper() for ext in exts]
for ext in exts:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext)))
assert len(images) > 0, "no image found in {} with " \ assert len(images) > 0, "no image found in {}".format(infer_dir)
"extension {}".format(infer_dir, image_ext)
logger.info("Found {} inference images in total.".format(len(images))) logger.info("Found {} inference images in total.".format(len(images)))
return images return images
def main(): def main():
args = parse_args() cfg = load_config(FLAGS.config)
cfg = load_config(args.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg['architecture'] main_arch = cfg.architecture
else: else:
raise ValueError("'architecture' not specified in config file.") raise ValueError("'architecture' not specified in config file.")
merge_config(args.cli_config) merge_config(FLAGS.opt)
if 'test_feed' not in cfg: if 'test_feed' not in cfg:
test_feed = create(main_arch + 'TestFeed') test_feed = create(main_arch + 'TestFeed')
else: else:
test_feed = create(cfg['test_feed']) test_feed = create(cfg.test_feed)
test_images = get_test_images(args.infer_dir, args.infer_img) test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
test_feed.dataset.add_images(test_images) test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
model = create(main_arch) model = create(main_arch)
...@@ -100,8 +112,8 @@ def main(): ...@@ -100,8 +112,8 @@ def main():
feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values()) feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values())
exe.run(startup_prog) exe.run(startup_prog)
if cfg['weights']: if cfg.weights:
checkpoint.load_checkpoint(exe, infer_prog, cfg['weights']) checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
# parse infer fetches # parse infer fetches
extra_keys = [] extra_keys = []
...@@ -110,9 +122,9 @@ def main(): ...@@ -110,9 +122,9 @@ def main():
keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
# 6. Parse dataset category # 6. Parse dataset category
if cfg['metric'] == 'COCO': if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
if cfg['metric'] == "VOC": if cfg.metric == "VOC":
# TODO(dengkaipeng): add VOC metric process # TODO(dengkaipeng): add VOC metric process
pass pass
...@@ -134,21 +146,42 @@ def main(): ...@@ -134,21 +146,42 @@ def main():
im_id = int(res['im_id'][0]) im_id = int(res['im_id'][0])
image_path = imid2path[im_id] image_path = imid2path[im_id]
if cfg['metric'] == 'COCO': if cfg.metric == 'COCO':
bbox_results = None bbox_results = None
mask_results = None mask_results = None
if 'bbox' in res: if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid) bbox_results = bbox2out([res], clsid2catid)
if 'mask' in res: if 'mask' in res:
mask_results = mask2out([res], clsid2catid, mask_results = mask2out([res], clsid2catid,
cfg['MaskHead']['resolution']) cfg.MaskHead.resolution)
visualize_results(image_path, catid2name, 0.5, bbox_results, image = Image.open(image_path)
mask_results) image = visualize_results(image, catid2name, 0.5,
bbox_results, mask_results)
if cfg['metric'] == "VOC": save_name = get_save_image_name(FLAGS.output_dir, image_path)
logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name)
if cfg.metric == "VOC":
# TODO(dengkaipeng): add VOC metric process # TODO(dengkaipeng): add VOC metric process
pass pass
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Directory for images to perform inference on.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path, has higher priority over --infer_dir")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory for storing the output visualization files.")
FLAGS = parser.parse_args()
main() main()
...@@ -29,7 +29,7 @@ from ppdet.data.data_feed import create_reader ...@@ -29,7 +29,7 @@ from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import parse_args from ppdet.utils.cli import ArgsParser
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feeds from ppdet.modeling.model_input import create_feeds
...@@ -40,33 +40,33 @@ logger = logging.getLogger(__name__) ...@@ -40,33 +40,33 @@ logger = logging.getLogger(__name__)
def main(): def main():
args = parse_args() cfg = load_config(FLAGS.config)
cfg = load_config(args.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg['architecture'] main_arch = cfg.architecture
else: else:
raise ValueError("'architecture' not specified in config file.") raise ValueError("'architecture' not specified in config file.")
merge_config(args.cli_config) merge_config(FLAGS.opt)
if cfg['use_gpu']: if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count() devices_num = fluid.core.get_cuda_device_count()
else: else:
devices_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) devices_num = int(os.environ.get('CPU_NUM',
multiprocessing.cpu_count()))
if 'train_feed' not in cfg: if 'train_feed' not in cfg:
train_feed = create(main_arch + 'TrainFeed') train_feed = create(main_arch + 'TrainFeed')
else: else:
train_feed = create(cfg['train_feed']) train_feed = create(cfg.train_feed)
if args.eval: if FLAGS.eval:
if 'eval_feed' not in cfg: if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed') eval_feed = create(main_arch + 'EvalFeed')
else: else:
eval_feed = create(cfg['eval_feed']) eval_feed = create(cfg.eval_feed)
place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
model = create(main_arch) model = create(main_arch)
...@@ -84,14 +84,14 @@ def main(): ...@@ -84,14 +84,14 @@ def main():
optimizer = optim_builder(lr) optimizer = optim_builder(lr)
optimizer.minimize(loss) optimizer.minimize(loss)
train_reader = create_reader(train_feed, cfg['max_iters'] * devices_num) train_reader = create_reader(train_feed, cfg.max_iters * devices_num)
train_pyreader.decorate_sample_list_generator(train_reader, place) train_pyreader.decorate_sample_list_generator(train_reader, place)
# parse train fetches # parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches) train_keys, train_values, _ = parse_fetches(train_fetches)
train_values.append(lr) train_values.append(lr)
if args.eval: if FLAGS.eval:
eval_prog = fluid.Program() eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog): with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
...@@ -103,7 +103,7 @@ def main(): ...@@ -103,7 +103,7 @@ def main():
eval_pyreader.decorate_sample_list_generator(eval_reader, place) eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse train fetches # parse train fetches
extra_keys = ['im_info', 'im_id'] if cfg['metric'] == 'COCO' else [] extra_keys = ['im_info', 'im_id'] if cfg.metric == 'COCO' else []
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys) extra_keys)
...@@ -116,27 +116,27 @@ def main(): ...@@ -116,27 +116,27 @@ def main():
train_compile_program = fluid.compiler.CompiledProgram( train_compile_program = fluid.compiler.CompiledProgram(
train_prog).with_data_parallel( train_prog).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy)
if args.eval: if FLAGS.eval:
eval_compile_program = fluid.compiler.CompiledProgram(eval_prog) eval_compile_program = fluid.compiler.CompiledProgram(eval_prog)
exe.run(startup_prog) exe.run(startup_prog)
freeze_bn = getattr(model.backbone, 'freeze_norm', False) freeze_bn = getattr(model.backbone, 'freeze_norm', False)
if args.resume_checkpoint: if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, args.resume_checkpoint) checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
elif cfg['pretrain_weights'] and freeze_bn: elif cfg.pretrain_weights and freeze_bn:
checkpoint.load_and_fusebn(exe, train_prog, cfg['pretrain_weights']) checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
elif cfg['pretrain_weights']: elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg['pretrain_weights']) checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
train_stats = TrainingStats(cfg['log_smooth_window'], train_keys) train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
train_pyreader.start() train_pyreader.start()
start_time = time.time() start_time = time.time()
end_time = time.time() end_time = time.time()
cfg_name = os.path.basename(args.config).split('.')[0] cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(cfg['save_dir'], cfg_name) save_dir = os.path.join(cfg.save_dir, cfg_name)
for it in range(cfg['max_iters']): for it in range(cfg.max_iters):
start_time = end_time start_time = end_time
end_time = time.time() end_time = time.time()
outs = exe.run(train_compile_program, fetch_list=train_values) outs = exe.run(train_compile_program, fetch_list=train_values)
...@@ -147,19 +147,40 @@ def main(): ...@@ -147,19 +147,40 @@ def main():
it, np.mean(outs[-1]), logs, end_time - start_time) it, np.mean(outs[-1]), logs, end_time - start_time)
logger.info(strs) logger.info(strs)
if it > 0 and it % cfg['snapshot_iter'] == 0: if it > 0 and it % cfg.snapshot_iter == 0:
checkpoint.save(exe, train_prog, os.path.join(save_dir, str(it))) checkpoint.save(exe, train_prog, os.path.join(save_dir, str(it)))
if args.eval: if FLAGS.eval:
# Run evaluation # Run evaluation
results = eval_run(exe, eval_compile_program, eval_pyreader, results = eval_run(exe, eval_compile_program, eval_pyreader,
eval_keys, eval_values, eval_cls) eval_keys, eval_values, eval_cls)
# Evaluation # Evaluation
eval_results(results, eval_feed, args, cfg) eval_results(results, eval_feed, cfg.metric,
cfg.MaskHead.resolution, FLAGS.output_file)
checkpoint.save(exe, train_prog, os.path.join(save_dir, "model_final")) checkpoint.save(exe, train_prog, os.path.join(save_dir, "model_final"))
train_pyreader.reset() train_pyreader.reset()
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-r",
"--resume_checkpoint",
default=None,
type=str,
help="Checkpoint path for resuming training.")
parser.add_argument(
"--eval",
action='store_true',
default=False,
help="Whether to perform evaluation in train")
parser.add_argument(
"-f",
"--output_file",
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json."
)
FLAGS = parser.parse_args()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册