model.py 4.1 KB
Newer Older
P
peterzhang2029 已提交
1 2 3 4 5
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2 import evaluator
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru
P
peterzhang2029 已提交
6
from config import ModelConfig as conf
P
peterzhang2029 已提交
7 8 9 10 11


class Model(object):
    '''
    CNN + bidirectional GRU network trained with a CTC loss to recognise
    character sequences in images.

    After construction these layers are exposed as attributes:
    ``image`` (input), ``output`` (unnormalised per-step scores over
    ``num_classes + 1`` labels, the last label being the CTC blank) and
    ``log_probs`` (softmax of ``output``). When ``is_infer`` is False,
    ``label``, ``cost`` (warp-CTC cost) and ``eval`` (CTC error evaluator)
    are created as well.
    '''

    def __init__(self, num_classes, shape, is_infer=False):
        '''
        :param num_classes: The size of the character dict.
        :type num_classes: int
        :param shape: The size of the input images as (height, width).
        :type shape: tuple of 2 int
        :param is_infer: Whether the network is built for inference only.
                         If True, the label input, the cost and the
                         evaluator are not created.
        :type is_infer: bool
        '''
        self.num_classes = num_classes
        self.shape = shape
        self.is_infer = is_infer
        # The image is fed as one flattened dense vector of height * width.
        self.image_vector_size = shape[0] * shape[1]

        self.__declare_input_layers__()
        self.__build_nn__()

    def __declare_input_layers__(self):
        '''
        Define the input layers.

        ``image`` is always created; ``label`` only when ``is_infer`` is
        False.
        '''
        # Image input as a flattened float vector. height/width are passed
        # along, presumably so the conv layers can recover the 2-D layout.
        self.image = layer.data(
            name='image',
            type=paddle.data_type.dense_vector(self.image_vector_size),
            height=self.shape[0],
            width=self.shape[1])

        # Label input as a sequence of character IDs (value range
        # num_classes); only needed to compute the training cost.
        if not self.is_infer:
            self.label = layer.data(
                name='label',
                type=paddle.data_type.integer_value_sequence(self.num_classes))

    def __build_nn__(self):
        '''
        Build the network topology.

        Pipeline: conv groups -> block_expand into a sequence ->
        forward and backward GRUs -> per-step fully-connected output,
        plus the warp-CTC cost and CTC error evaluator when training.
        '''
        # CNN output image features.
        conv_features = self.conv_groups(self.image, conf.filter_num,
                                         conf.with_bn)

        # Cut the CNN feature map into a sequence of feature vectors, one per
        # block of conf.block_x by conf.block_y pixels (originally described
        # as 1 pixel wide and 11 pixels high; depends on the config).
        sliced_feature = layer.block_expand(
            input=conv_features,
            num_channels=conf.num_channels,
            stride_x=conf.stride_x,
            stride_y=conf.stride_y,
            block_x=conf.block_x,
            block_y=conf.block_y)

        # RNNs to capture sequence information forwards and backwards.
        gru_forward = simple_gru(
            input=sliced_feature, size=conf.hidden_size, act=Relu())
        gru_backward = simple_gru(
            input=sliced_feature,
            size=conf.hidden_size,
            act=Relu(),
            reverse=True)

        # Map each step of the RNNs to scores over num_classes characters
        # plus one extra label used as the CTC blank. The activation is linear
        # because warp_ctc expects unnormalised input (it applies softmax
        # itself, per the PaddlePaddle warp_ctc docs).
        self.output = layer.fc(
            input=[gru_forward, gru_backward],
            size=self.num_classes + 1,
            act=Linear())

        # Softmax-normalised per-step distribution. Despite the attribute
        # name, these are probabilities, not log-probabilities.
        self.log_probs = layer.mixed(
            input=layer.identity_projection(input=self.output),
            act=paddle.activation.Softmax())

        if not self.is_infer:
            # Use warp CTC to calculate cost for a CTC task. The blank is the
            # extra last label (index num_classes) added to the fc output.
            self.cost = layer.warp_ctc(
                input=self.output,
                label=self.label,
                size=self.num_classes + 1,
                norm_by_times=conf.norm_by_times,
                blank=self.num_classes)

            self.eval = evaluator.ctc_error(input=self.output, label=self.label)

    def conv_groups(self, input, num, with_bn):
        '''
        Stack one image conv group per entry of ``conf.filter_num_list``.

        Each group contains ``num // 4`` conv layers. All of them use that
        group's filter count from ``conf.filter_num_list``, and each group
        ends with pooling configured by ``conf.pool_size`` /
        ``conf.pool_stride``.

        :param input: Input layer.
        :type input: LayerOutput
        :param num: Four times the number of conv layers per group; must be
                    a positive multiple of 4.
        :type num: int
        :param with_bn: Whether with batch normalization.
        :type with_bn: bool
        :return: Output layer of the last conv group (``input`` itself if
                 ``conf.filter_num_list`` is empty).
        :rtype: LayerOutput
        :raises ValueError: If ``num`` is not a positive multiple of 4.
        '''
        # Validate explicitly: an ``assert`` would be stripped under ``-O``.
        if num <= 0 or num % 4 != 0:
            raise ValueError(
                'num must be a positive multiple of 4, got %r' % (num,))

        # Floor division keeps the repeat count an int on Python 3 too, where
        # ``num / 4`` yields a float and list repetition would fail.
        convs_per_group = num // 4

        tmp = input
        for index, num_filter in enumerate(conf.filter_num_list):
            # Only the first group reads the raw single-channel image; later
            # groups pass None (presumably inferred from the previous layer).
            num_channels = 1 if index == 0 else None
            tmp = img_conv_group(
                input=tmp,
                num_channels=num_channels,
                conv_padding=conf.conv_padding,
                conv_num_filter=[num_filter] * convs_per_group,
                conv_filter_size=conf.conv_filter_size,
                conv_act=Relu(),
                conv_with_batchnorm=with_bn,
                pool_size=conf.pool_size,
                pool_stride=conf.pool_stride)

        return tmp