network_conf.py 4.2 KB
Newer Older
P
peterzhang2029 已提交
1 2 3 4 5
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2 import evaluator
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru
P
peterzhang2029 已提交
6
from config import ModelConfig as conf
P
peterzhang2029 已提交
7 8 9 10 11


class Model(object):
    '''
    Sequence-recognition network: CNN image features are expanded into a
    sequence, read by a bidirectional GRU, mapped to a per-step character
    distribution and (when training) scored with warp-CTC.
    '''

    def __init__(self, num_classes, shape, is_infer=False):
        '''
        Declare the input layers and build the network topology.

        :param num_classes: The size of the character dict.
        :type num_classes: int
        :param shape: The size of the input images as (height, width).
        :type shape: tuple of 2 int
        :param is_infer: The boolean parameter indicating inferring (True)
                         or training (False). When True, no label input,
                         cost or evaluator is created.
        :type is_infer: bool
        '''
        self.num_classes = num_classes
        self.shape = shape
        self.is_infer = is_infer
        # The image is fed as a flat vector of height * width values.
        self.image_vector_size = shape[0] * shape[1]

        self.__declare_input_layers__()
        self.__build_nn__()

    def __declare_input_layers__(self):
        '''
        Define the input layers.

        Always declares ``self.image``; declares ``self.label`` only when
        training (``self.is_infer`` is False).
        '''
        # Image input as a float vector, with height/width declared
        # explicitly from ``self.shape``.
        self.image = layer.data(
            name='image',
            type=paddle.data_type.dense_vector(self.image_vector_size),
            height=self.shape[0],
            width=self.shape[1])

        # Label input as an ID sequence; only needed to compute the cost.
        if not self.is_infer:
            self.label = layer.data(
                name='label',
                type=paddle.data_type.integer_value_sequence(self.num_classes))

    def __build_nn__(self):
        '''
        Build the network topology.

        Sets ``self.output`` (linear per-step scores) and ``self.log_probs``
        (softmax over them). When training, also sets ``self.cost``
        (warp-CTC) and ``self.eval`` (CTC error evaluator).
        '''
        # Get the image features with CNN.
        conv_features = self.conv_groups(self.image, conf.filter_num,
                                         conf.with_bn)

        # Expand the output of CNN into a sequence of feature vectors.
        sliced_feature = layer.block_expand(
            input=conv_features,
            num_channels=conf.num_channels,
            stride_x=conf.stride_x,
            stride_y=conf.stride_y,
            block_x=conf.block_x,
            block_y=conf.block_y)

        # Use RNN to capture sequence information forwards and backwards.
        gru_forward = simple_gru(
            input=sliced_feature, size=conf.hidden_size, act=Relu())
        gru_backward = simple_gru(
            input=sliced_feature,
            size=conf.hidden_size,
            act=Relu(),
            reverse=True)

        # Map the output of RNN to character distribution. The extra
        # (num_classes + 1)-th output is the index used as the CTC blank
        # below.
        self.output = layer.fc(input=[gru_forward, gru_backward],
                               size=self.num_classes + 1,
                               act=Linear())

        # NOTE(review): despite its name this layer applies Softmax, so it
        # holds probabilities rather than log-probabilities; the attribute
        # name is kept for existing callers.
        self.log_probs = layer.mixed(
            input=layer.identity_projection(input=self.output),
            act=paddle.activation.Softmax())

        # Use warp CTC to calculate cost for a CTC task. It is fed the
        # linear (un-normalized) output; blank is the last class index.
        if not self.is_infer:
            self.cost = layer.warp_ctc(
                input=self.output,
                label=self.label,
                size=self.num_classes + 1,
                norm_by_times=conf.norm_by_times,
                blank=self.num_classes)

            self.eval = evaluator.ctc_error(input=self.output, label=self.label)

    @staticmethod
    def _conv_num_filters(num_filter, num):
        '''
        Build the ``conv_num_filter`` list for one convolution group.

        :param num_filter: Number of filters in every convolution layer of
                           the group.
        :type num_filter: int
        :param num: Depth control; the group gets ``num // 4`` layers.
        :type num: int
        :return: ``[num_filter] * (num // 4)``.
        :rtype: list of int
        '''
        # Floor division keeps the repeat count an int under Python 3, where
        # ``num / 4`` is a float and ``list * float`` raises TypeError.
        return [num_filter] * (num // 4)

    def conv_groups(self, input, num, with_bn):
        '''
        Get the image features with a stack of image convolution groups.

        One ``img_conv_group`` is built per entry of ``conf.filter_num_list``.
        Each group stacks ``num // 4`` ReLU convolution layers using that
        entry's filter count, with padding, filter size and pooling taken
        from ``conf``.

        :param input: Input layer.
        :type input: LayerOutput
        :param num: Controls the depth of every group (``num // 4``
                    convolution layers per group). Must be a multiple of 4.
        :type num: int
        :param with_bn: Use batch normalization or not.
        :type with_bn: bool
        :return: Output of the last convolution group (``input`` itself if
                 ``conf.filter_num_list`` is empty).
        :rtype: LayerOutput
        :raises ValueError: If ``num`` is not a multiple of 4.
        '''
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if num % 4 != 0:
            raise ValueError('num must be a multiple of 4, got %r' % (num, ))

        tmp = input
        for index, num_filter in enumerate(conf.filter_num_list):
            # Only the first group reads the raw image, declared here as a
            # single channel; later groups pass None (presumably so Paddle
            # infers channels from the previous group -- confirm).
            num_channels = 1 if index == 0 else None
            tmp = img_conv_group(
                input=tmp,
                num_channels=num_channels,
                conv_padding=conf.conv_padding,
                conv_num_filter=self._conv_num_filters(num_filter, num),
                conv_filter_size=conf.conv_filter_size,
                conv_act=Relu(),
                conv_with_batchnorm=with_bn,
                pool_size=conf.pool_size,
                pool_stride=conf.pool_stride)

        return tmp