Unverified commit cf872f91, authored by qingqing01, committed by GitHub

Remove the un-used code (#194)

Parent ccf94523
architecture: FasterRCNN
use_gpu: true
max_iters: 180000
log_smooth_window: 20
save_dir: output
snapshot_iter: 10000
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_1x/model_final
num_classes: 81

FasterRCNN:
  backbone: ResNet
  rpn_head: RPNHead
  roi_extractor: RoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

ResNet:
  norm_type: affine_channel
  depth: 50
  feature_maps: 4
  freeze_at: 2

ResNetC5:
  depth: 50
  norm_type: affine_channel

RPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
    use_random: true
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 12000
    post_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 6000
    post_nms_top_n: 1000

RoIAlign:
  resolution: 14
  sampling_ratio: 0
  spatial_scale: 0.0625

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  head: ResNetC5
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [120000, 160000]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_LOADER_: 'faster_reader.yml'
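For orientation between the two config files, the LearningRate block above composes a LinearWarmup with a PiecewiseDecay: base_lr 0.01 is scaled from start_factor (1/3) up to 1.0 over the first 500 iterations, then multiplied by gamma 0.1 at iterations 120000 and 160000. The snippet below is a minimal, hypothetical re-implementation of that schedule for illustration only; it is not the ppdet scheduler code.

def lr_at(it, base_lr=0.01, start_factor=1.0 / 3, warmup_steps=500,
          milestones=(120000, 160000), gamma=0.1):
    # Warmup: linearly interpolate the scale factor from start_factor to 1.0.
    if it < warmup_steps:
        alpha = it / float(warmup_steps)
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # Afterwards: one gamma multiplication for each milestone already passed.
    passed = sum(1 for m in milestones if it >= m)
    return base_lr * (gamma ** passed)

# e.g. lr_at(0) ~= 0.0033, lr_at(500) == 0.01, lr_at(130000) == 0.001, lr_at(170000) == 0.0001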
TrainLoader:
  inputs_def:
    image_shape: [3, 800, 800]
    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
  batch_size: 3

TrainReader:
  inputs_def:
    image_shape: [3, NULL, NULL]
    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
  dataset:
    !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
    dataset_dir: dataset/coco
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeImage
    target_size: 800
    max_size: 1333
    interp: 1
    use_cv2: true
  - !Permute
    to_bgr: false
    channel_first: true
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: false
  batch_size: 1
  shuffle: true
  worker_num: 2
  drop_last: false
  use_multi_process: false

EvalReader:
  inputs_def:
    image_shape: [3, 800, 1333]
    fields: ['image', 'im_info', 'im_id', 'im_shape']
    # for voc
    #fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
  dataset:
    !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
    dataset_dir: dataset/coco
    #sample_num: 100
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  shuffle: false
  drop_last: false
  # worker_num: 2

TestReader:
  inputs_def:
    image_shape: [3, 800, 1333]
    fields: ['image', 'im_info', 'im_id', 'im_shape']
  dataset:
    !ImageFolder
    anno_path: annotations/instances_val2017.json
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  shuffle: false
  drop_last: false
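To make the reader pipelines above concrete, here is a minimal, hypothetical numpy/cv2 sketch of what the configured DecodeImage -> NormalizeImage -> ResizeImage -> Permute -> PadBatch chain does to a single evaluation image (ImageNet mean/std, short side 800 with the long side capped at 1333, CHW output padded to a stride of 32). It illustrates the transform semantics as configured here; it is not the ppdet implementation, and the im_info layout at the end is an assumption.

import numpy as np
import cv2  # use_cv2: true in the readers above

def preprocess(img_bgr):
    # DecodeImage(to_rgb=true): decoded BGR image -> RGB, float32
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype('float32')
    # NormalizeImage(is_scale=true): scale to [0, 1], then standardize
    img /= 255.0
    img -= np.array([0.485, 0.456, 0.406], dtype='float32')
    img /= np.array([0.229, 0.224, 0.225], dtype='float32')
    # ResizeImage: short side -> 800, keeping the long side <= 1333
    h, w = img.shape[:2]
    scale = min(800.0 / min(h, w), 1333.0 / max(h, w))
    img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=1)  # interp: 1
    # Permute(channel_first=true): HWC -> CHW
    img = img.transpose((2, 0, 1))
    # PadBatch(pad_to_stride=32): pad H and W up to the next multiple of 32
    c, h, w = img.shape
    ph = int(np.ceil(h / 32.0) * 32)
    pw = int(np.ceil(w / 32.0) * 32)
    padded = np.zeros((c, ph, pw), dtype='float32')
    padded[:, :h, :w] = img
    # assumed im_info layout: [resized_h, resized_w, scale]
    im_info = np.array([float(h), float(w), scale], dtype='float32')
    return padded, im_info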
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import multiprocessing
import numpy as np
import datetime
from collections import deque
import sys
sys.path.append("../../")
from paddle.fluid.contrib.slim import Compressor
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
def set_paddle_flags(**kwargs):
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)


# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)
from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def eval_run(exe, compile_program, reader, keys, values, cls, test_feed):
    """
    Run evaluation program, return program outputs.
    """
    iter_id = 0
    results = []
    images_num = 0
    start_time = time.time()
    has_bbox = 'bbox' in keys
    for data in reader():
        data = test_feed.feed(data)
        feed_data = {'image': data['image'], 'im_size': data['im_size']}
        outs = exe.run(compile_program,
                       feed=feed_data,
                       fetch_list=values[0],
                       return_numpy=False)
        outs.append(data['gt_box'])
        outs.append(data['gt_label'])
        outs.append(data['is_difficult'])
        res = {
            k: (np.array(v), v.recursive_sequence_lengths())
            for k, v in zip(keys, outs)
        }
        results.append(res)
        if iter_id % 100 == 0:
            logger.info('Test iter {}'.format(iter_id))
        iter_id += 1
        images_num += len(res['bbox'][1][0]) if has_bbox else 1
    logger.info('Test finish iter {}'.format(iter_id))

    end_time = time.time()
    fps = images_num / (end_time - start_time)
    if has_bbox:
        logger.info('Total number of images: {}, inference time: {} fps.'.
                    format(images_num, fps))
    else:
        logger.info('Total iteration: {}, inference time: {} batch/s.'.format(
            images_num, fps))
    return results


def main():
    cfg = load_config(FLAGS.config)

    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)

    if 'log_iter' not in cfg:
        cfg.log_iter = 20

    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(
            os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

    if 'eval_feed' not in cfg:
        eval_feed = create(main_arch + 'EvalFeed')
    else:
        eval_feed = create(cfg.eval_feed)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    _, test_feed_vars = create_feed(eval_feed, False)

    eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
    #eval_pyreader.decorate_sample_list_generator(eval_reader, place)
    test_data_feed = fluid.DataFeeder(test_feed_vars.values(), place)

    assert os.path.exists(FLAGS.model_path)
    infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname=FLAGS.model_path,
        executor=exe,
        model_filename=FLAGS.model_name,
        params_filename=FLAGS.params_name)

    eval_keys = ['bbox', 'gt_box', 'gt_label', 'is_difficult']
    eval_values = [
        'multiclass_nms_0.tmp_0', 'gt_box', 'gt_label', 'is_difficult'
    ]
    eval_cls = []
    eval_values[0] = fetch_targets[0]

    results = eval_run(exe, infer_prog, eval_reader, eval_keys, eval_values,
                       eval_cls, test_data_feed)

    resolution = None
    if 'mask' in results[0]:
        resolution = model.mask_head.resolution
    eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution,
                 False, FLAGS.output_eval)


if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "-m", "--model_path", default=None, type=str, help="path of checkpoint")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "-d",
        "--dataset_dir",
        default=None,
        type=str,
        help="Dataset path, same as DataFeed.dataset.dataset_dir")
    parser.add_argument(
        "--model_name",
        default='model',
        type=str,
        help="model file name to load_inference_model")
    parser.add_argument(
        "--params_name",
        default='params',
        type=str,
        help="params file name to load_inference_model")
    FLAGS = parser.parse_args()
    main()
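A note on the res dictionary built in eval_run above: with return_numpy=False each fetched value is a LoDTensor, so the code stores the flat detection array together with its recursive sequence lengths. The helper below is a hypothetical illustration (not part of the commit) of how those two pieces recover per-image detections; len(lod[0]) is the same quantity eval_run uses to count images, and the [N, 6] layout (class_id, score, x1, y1, x2, y2) is the usual multiclass_nms output, assumed here.

import numpy as np

def split_bboxes_per_image(res):
    """Hypothetical helper: split the fetched 'bbox' result back into
    per-image arrays, assuming a level-0 LoD over images."""
    bboxes, lod = res['bbox']   # (np.ndarray, recursive_sequence_lengths())
    lengths = lod[0]            # number of detections for each image
    offsets = np.cumsum([0] + list(lengths))
    return [bboxes[offsets[i]:offsets[i + 1]] for i in range(len(lengths))]

For reference, the evaluation script itself would typically be launched as something like `python <this script> -c faster_rcnn_r50_1x.yml -m <saved inference model dir>`, assuming the `-c`/`--config` option comes from ppdet's ArgsParser and using the flags declared above.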
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import glob
import time
import numpy as np
from PIL import Image
sys.path.append("../../")
def set_paddle_flags(**kwargs):
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:
            os.environ[key] = str(value)


# NOTE(paddle-dev): All of these flags should be set before
# `import paddle`. Otherwise, it would not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)
from paddle import fluid
from ppdet.utils.cli import print_total_cfg
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def get_save_image_name(output_dir, image_path):
    """
    Get save image name from source image path.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    image_name = os.path.split(image_path)[-1]
    name, ext = os.path.splitext(image_name)
    return os.path.join(output_dir, "{}".format(name)) + ext


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)
    images = []

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        images.append(infer_img)
        return images

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    assert len(images) > 0, "no image found in {}".format(infer_dir)

    logger.info("Found {} inference images in total.".format(len(images)))

    return images
def main():
    cfg = load_config(FLAGS.config)

    if 'architecture' in cfg:
        main_arch = cfg.architecture
    else:
        raise ValueError("'architecture' not specified in config file.")

    merge_config(FLAGS.opt)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # print_total_cfg(cfg)

    if 'test_feed' not in cfg:
        test_feed = create(main_arch + 'TestFeed')
    else:
        test_feed = create(cfg.test_feed)

    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    test_feed.dataset.add_images(test_images)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    infer_prog, feed_var_names, fetch_list = fluid.io.load_inference_model(
        dirname=FLAGS.model_path,
        model_filename=FLAGS.model_name,
        params_filename=FLAGS.params_name,
        executor=exe)

    reader = create_reader(test_feed)
    feeder = fluid.DataFeeder(
        place=place, feed_list=feed_var_names, program=infer_prog)

    # parse infer fetches
    assert cfg.metric in ['COCO', 'VOC'], \
        "unknown metric type {}".format(cfg.metric)
    extra_keys = []
    if cfg['metric'] == 'COCO':
        extra_keys = ['im_info', 'im_id', 'im_shape']
    if cfg['metric'] == 'VOC':
        extra_keys = ['im_id', 'im_shape']
    keys, values, _ = parse_fetches({
        'bbox': fetch_list
    }, infer_prog, extra_keys)

    # parse dataset category
    if cfg.metric == 'COCO':
        from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
    if cfg.metric == "VOC":
        from ppdet.utils.voc_eval import bbox2out, get_category_info

    anno_file = getattr(test_feed.dataset, 'annotation', None)
    with_background = getattr(test_feed, 'with_background', True)
    use_default_label = getattr(test_feed, 'use_default_label', False)
    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False

    # use tb-paddle to log image
    if FLAGS.use_tb:
        from tb_paddle import SummaryWriter
        tb_writer = SummaryWriter(FLAGS.tb_log_dir)
        tb_image_step = 0
        tb_image_frame = 0  # each frame can display ten pictures at most.

    imid2path = reader.imid2path
    keys = ['bbox']
    infer_time = True
    compile_prog = fluid.compiler.CompiledProgram(infer_prog)

    for iter_id, data in enumerate(reader()):
        feed_data = [[d[0], d[1]] for d in data]
        # for infer time
        if infer_time:
            warmup_times = 10
            repeats_time = 100
            feed_data_dict = feeder.feed(feed_data)
            for i in range(warmup_times):
                exe.run(compile_prog,
                        feed=feed_data_dict,
                        fetch_list=fetch_list,
                        return_numpy=False)
            start_time = time.time()
            for i in range(repeats_time):
                exe.run(compile_prog,
                        feed=feed_data_dict,
                        fetch_list=fetch_list,
                        return_numpy=False)
            print("infer time: {} ms/sample".format((time.time() - start_time) *
                                                    1000 / repeats_time))
            infer_time = False

        outs = exe.run(compile_prog,
                       feed=feeder.feed(feed_data),
                       fetch_list=fetch_list,
                       return_numpy=False)
        res = {
            k: (np.array(v), v.recursive_sequence_lengths())
            for k, v in zip(keys, outs)
        }
        res['im_id'] = [[d[2] for d in data]]
        logger.info('Infer iter {}'.format(iter_id))

        bbox_results = None
        mask_results = None
        if 'bbox' in res:
            bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
        if 'mask' in res:
            mask_results = mask2out([res], clsid2catid,
                                    model.mask_head.resolution)

        # visualize result
        im_ids = res['im_id'][0]
        for im_id in im_ids:
            image_path = imid2path[int(im_id)]
            image = Image.open(image_path).convert('RGB')

            # use tb-paddle to log original image
            if FLAGS.use_tb:
                original_image_np = np.array(image)
                tb_writer.add_image(
                    "original/frame_{}".format(tb_image_frame),
                    original_image_np,
                    tb_image_step,
                    dataformats='HWC')

            image = visualize_results(image,
                                      int(im_id), catid2name,
                                      FLAGS.draw_threshold, bbox_results,
                                      mask_results)

            # use tb-paddle to log image with bbox
            if FLAGS.use_tb:
                infer_image_np = np.array(image)
                tb_writer.add_image(
                    "bbox/frame_{}".format(tb_image_frame),
                    infer_image_np,
                    tb_image_step,
                    dataformats='HWC')
                tb_image_step += 1
                if tb_image_step % 10 == 0:
                    tb_image_step = 0
                    tb_image_frame += 1

            save_name = get_save_image_name(FLAGS.output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(save_name))
            image.save(save_name, quality=95)
if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_tb",
        type=bool,
        default=False,
        help="whether to record the data to Tensorboard.")
    parser.add_argument(
        '--tb_log_dir',
        type=str,
        default="tb_log_dir/image",
        help='Tensorboard logging directory for image.')
    parser.add_argument(
        '--model_path', type=str, default=None, help="inference model path")
    parser.add_argument(
        '--model_name',
        type=str,
        default='__model__.infer',
        help="model filename for inference model")
    parser.add_argument(
        '--params_name',
        type=str,
        default='__params__',
        help="params filename for inference model")
    FLAGS = parser.parse_args()
    main()
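One editorial note on the argument definitions above: argparse's type=bool does not parse the strings 'True'/'False'; any non-empty string is truthy, so passing `--use_tb False` on the command line still enables TensorBoard logging. A small self-contained demonstration of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_tb", type=bool, default=False)

# bool("False") is True, so this prints True rather than False
print(parser.parse_args(["--use_tb", "False"]).use_tb)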