未验证 提交 0d21ed97 编写于 作者: Y Yang Zhang 提交者: GitHub

Add an ipython notebook demo (#2553)

* Fix missing `absolute_import`

* Fix cycle import

* Fix missing export

* Remove some extra blank line and whitespaces

* Rename `create_feeds` -> `create_feed`

* Fix minor issues in tool scripts

* Add ipython notebook demo

* Tweak visualization bbox label style
上级 334b0d8c
......@@ -20,7 +20,6 @@ MaskRCNN:
mask_assigner: MaskAssigner
mask_head: MaskHead
ResNet:
norm_type: affine_channel
norm_decay: true
......
......@@ -4,11 +4,11 @@ eval_feed: MaskRCNNEvalFeed
test_feed: MaskRCNNTestFeed
max_iters: 260000
snapshot_iter: 10000
use_gpu: True
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/SE154_vd_pretrained.tar
weights: output/mask_rcnn_se154_vd_fpn_s1x/model_final/
weights: output/mask_rcnn_se154_vd_fpn_s1x/model_final/
metric: COCO
MaskRCNN:
......@@ -26,7 +26,7 @@ SENet:
group_width: 4
groups: 64
norm_type: affine_channel
variant: d
variant: d
FPN:
max_level: 6
......@@ -124,7 +124,7 @@ MaskRCNNTrainFeed:
- !PadBatch
pad_to_stride: 32
dataset:
dataset_dir: data/coco
dataset_dir: data/coco
image_dir: train2017
annotation: annotations/instances_train2017.json
num_workers: 2
......
此差异已折叠。
......@@ -14,4 +14,4 @@
import ppdet.modeling
import ppdet.optimizer
import ppdet.data.data_feed
import ppdet.data
......@@ -33,6 +33,8 @@
# * 'MappedDataset' accept a 'xxxSource' and a list of 'xxxOperator'
# to build a transformed 'Dataset'
from __future__ import absolute_import
from .dataset import Dataset
from .reader import Reader
from .data_feed import create_reader
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
# XXX for triggering decorators
from . import anchor_heads
from . import architectures
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import rpn_head
from . import yolo_head
from . import retina_head
......
......@@ -22,9 +22,10 @@ from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import AnchorGenerator, RPNTargetAssign, GenerateProposals
from ppdet.modeling.ops import (AnchorGenerator,
RPNTargetAssign, GenerateProposals)
__all__ = ['RPNTargetAssign', 'GenerateProposals', 'RPNHead']
__all__ = ['RPNTargetAssign', 'GenerateProposals', 'RPNHead', 'FPNRPNHead']
@register
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import faster_rcnn
from . import mask_rcnn
from . import cascade_rcnn
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import resnet
from . import resnext
from . import darknet
......
......@@ -20,7 +20,7 @@ from collections import OrderedDict
from paddle import fluid
__all__ = ['create_feeds']
__all__ = ['create_feed']
# yapf: disable
feed_var_def = [
......@@ -37,7 +37,7 @@ feed_var_def = [
# yapf: enable
def create_feeds(feed, use_pyreader=True):
def create_feed(feed, use_pyreader=True):
image_shape = feed.image_shape
feed_var_map = {var['name']: var for var in feed_var_def}
feed_var_map['image'] = {
......
......@@ -12,5 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import roi_extractor
from .roi_extractor import *
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import bbox_head
from . import mask_head
from . import cascade_head
......
......@@ -26,7 +26,7 @@ from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS
from ppdet.core.workspace import register, serializable
__all__ = ['BBoxHead']
__all__ = ['BBoxHead', 'TwoFCHead']
@register
......
......@@ -23,7 +23,7 @@ import paddle.fluid as fluid
from ppdet.modeling.tests.decorator_helper import prog_scope
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feeds
from ppdet.modeling.model_input import create_feed
class TestFasterRCNN(unittest.TestCase):
......@@ -39,14 +39,14 @@ class TestFasterRCNN(unittest.TestCase):
def test_train(self):
train_feed = create(self.cfg['train_feed'])
model = create(self.detector_type)
_, feed_vars = create_feeds(train_feed)
_, feed_vars = create_feed(train_feed)
train_fetches = model.train(feed_vars)
@prog_scope()
def test_test(self):
test_feed = create(self.cfg['eval_feed'])
model = create(self.detector_type)
_, feed_vars = create_feeds(test_feed)
_, feed_vars = create_feed(test_feed)
test_fetches = model.eval(feed_vars)
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import numpy as np
import pycocotools.mask as mask_util
from PIL import Image, ImageDraw
......@@ -26,8 +25,6 @@ from .colormap import colormap
__all__ = ['visualize_results']
logger = logging.getLogger(__name__)
def visualize_results(image,
im_id,
......@@ -73,7 +70,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7):
return Image.fromarray(img_array.astype('uint8'))
def draw_bbox(image, im_id, catid2name, bboxes, threshold,
def draw_bbox(image, im_id, catid2name, bboxes, threshold,
is_bbox_normalized=False):
"""
Draw bbox on image
......@@ -103,8 +100,11 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold,
width=2,
fill='red')
if image.mode == 'RGB':
draw.text((xmin, ymin), catid2name[catid], (255, 255, 0))
logger.debug("\t {:15s} at {:25} score: {:.5f}".format(catid2name[catid],
str(list(map(int, list([xmin, ymin, xmax, ymax])))), score))
text = catid2name[catid]
tw, th = draw.textsize(text)
draw.rectangle([(xmin + 1, ymin + 1),
(xmin + tw + 1, ymin + th + 1)],
fill='red')
draw.text((xmin + 1, ymin + 1), text, fill=(255, 255, 255))
return image
......@@ -24,7 +24,7 @@ import paddle.fluid as fluid
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser
from ppdet.modeling.model_input import create_feeds
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.core.workspace import load_config, merge_config, create
......@@ -68,7 +68,7 @@ def main():
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
pyreader, feed_vars = create_feeds(eval_feed)
pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
......@@ -100,7 +100,7 @@ def main():
results = eval_run(exe, compile_program, pyreader, keys, values, cls)
# Evaluation
eval_results(results, eval_feed, cfg.metric,
cfg.MaskHead.resolution, FLAGS.output_file)
model.mask_head.resolution, FLAGS.output_file)
if __name__ == '__main__':
......
......@@ -25,7 +25,7 @@ from PIL import Image
from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feeds
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches
......@@ -55,7 +55,7 @@ def get_test_images(infer_dir, infer_img):
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer-img or --infer-dir should be set"
"--infer_img or --infer_dir should be set"
images = []
# infer_img has a higher priority
......@@ -104,7 +104,7 @@ def main():
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
_, feed_vars = create_feeds(test_feed, use_pyreader=False)
_, feed_vars = create_feed(test_feed, use_pyreader=False)
test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True)
......@@ -151,11 +151,11 @@ def main():
mask_results = None
is_bbox_normalized = True if cfg.metric == 'VOC' else False
if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid,
bbox_results = bbox2out([res], clsid2catid,
is_bbox_normalized)
if 'mask' in res:
mask_results = mask2out([res], clsid2catid,
cfg.MaskHead['resolution'])
model.mask_head.resolution)
# visualize result
im_ids = res['im_id'][0]
......
......@@ -31,7 +31,7 @@ from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feeds
from ppdet.modeling.model_input import create_feed
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
......@@ -77,7 +77,7 @@ def main():
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
train_pyreader, feed_vars = create_feeds(train_feed)
train_pyreader, feed_vars = create_feed(train_feed)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
lr = lr_builder()
......@@ -95,7 +95,7 @@ def main():
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
eval_pyreader, feed_vars = create_feeds(eval_feed)
eval_pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
......@@ -156,7 +156,7 @@ def main():
eval_keys, eval_values, eval_cls)
# Evaluation
eval_results(results, eval_feed, cfg.metric,
cfg.MaskHead.resolution, FLAGS.output_file)
model.mask_head.resolution, FLAGS.output_file)
checkpoint.save(exe, train_prog, os.path.join(save_dir, "model_final"))
train_pyreader.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册