# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from collections import OrderedDict

from paddle import fluid

__all__ = ['create_feed']

# Static specs for the data layers that create_feed() can build. Each entry
# supplies the `name`, `shape`, `dtype` and `lod_level` keyword arguments that
# create_feed() forwards to fluid.layers.data(). The 'image' entry is not
# listed here because create_feed() takes its shape from `feed.image_shape`.
# yapf: disable
feed_var_def = [
    {'name': 'im_info',       'shape': [3],  'dtype': 'float32', 'lod_level': 0},
    {'name': 'im_id',         'shape': [1],  'dtype': 'int32',   'lod_level': 0},
    {'name': 'gt_box',        'shape': [4],  'dtype': 'float32', 'lod_level': 1},
    {'name': 'gt_label',      'shape': [1],  'dtype': 'int32',   'lod_level': 1},
    {'name': 'is_crowd',      'shape': [1],  'dtype': 'int32',   'lod_level': 1},
    {'name': 'gt_mask',       'shape': [2],  'dtype': 'float32', 'lod_level': 3},
    {'name': 'is_difficult',  'shape': [1],  'dtype': 'int32',   'lod_level': 1},
    {'name': 'gt_score',      'shape': [1],  'dtype': 'float32', 'lod_level': 0},
    {'name': 'im_shape',      'shape': [3],  'dtype': 'float32', 'lod_level': 0},
]
# yapf: enable


40
def create_feed(feed, use_pyreader=True):
    """Build the fluid data layers (and optionally a PyReader) for a feed.

    Args:
        feed: feed configuration object. This function reads
            ``feed.image_shape`` and ``feed.fields``, plus the optional
            ``feed.num_max_boxes`` attribute.
        use_pyreader (bool): when True, also create a non-iterable
            ``fluid.io.PyReader`` whose feed list is the created variables.

    Returns:
        tuple: ``(pyreader, feed_vars)``. ``pyreader`` is the PyReader, or
        ``None`` when ``use_pyreader`` is False. ``feed_vars`` is an
        ``OrderedDict`` mapping each name in ``feed.fields`` (in that order)
        to the variable returned by ``fluid.layers.data``.

    Raises:
        KeyError: if ``feed.fields`` contains a name that is neither
            ``'image'`` nor listed in ``feed_var_def``.
    """
    # Copy every spec (and its shape list) so the per-call overrides below
    # never leak into the shared module-level `feed_var_def` table; without
    # the copy, one call with `num_max_boxes` set would change the specs seen
    # by every later call.
    feed_var_map = {
        var['name']: dict(var, shape=list(var['shape']))
        for var in feed_var_def
    }
    feed_var_map['image'] = {
        'name': 'image',
        'shape': feed.image_shape,
        'dtype': 'float32',
        'lod_level': 0
    }

    # YOLO var dim is fixed: when the feed declares `num_max_boxes`, the
    # ground-truth variables get a fixed leading dimension and lod_level 0,
    # and `im_shape` becomes a 2-element int32 variable.
    num_max_boxes = getattr(feed, 'num_max_boxes', None)
    if num_max_boxes is not None:
        feed_var_map['gt_label'].update(shape=[num_max_boxes], lod_level=0)
        feed_var_map['gt_score'].update(shape=[num_max_boxes], lod_level=0)
        feed_var_map['gt_box'].update(shape=[num_max_boxes, 4], lod_level=0)
        feed_var_map['im_shape'].update(shape=[2], dtype='int32')

    # Create one data layer per requested field, preserving field order.
    feed_vars = OrderedDict()
    for key in feed.fields:
        spec = feed_var_map[key]
        feed_vars[key] = fluid.layers.data(
            name=spec['name'],
            shape=spec['shape'],
            dtype=spec['dtype'],
            lod_level=spec['lod_level'])

    pyreader = None
    if use_pyreader:
        pyreader = fluid.io.PyReader(
            feed_list=list(feed_vars.values()),
            capacity=64,
            use_double_buffer=True,
            iterable=False)
    return pyreader, feed_vars