model.py 4.0 KB
Newer Older
P
peterzhang2029 已提交
1 2 3 4 5
from paddle import v2 as paddle
from paddle.v2 import layer
from paddle.v2 import evaluator
from paddle.v2.activation import Relu, Linear
from paddle.v2.networks import img_conv_group, simple_gru
P
peterzhang2029 已提交
6
from config import ModelConfig as conf
P
peterzhang2029 已提交
7 8 9 10 11


class Model(object):
    '''
    Text recognition network trained with CTC.

    A stack of convolution groups extracts image features,
    ``layer.block_expand`` slices the feature map into a sequence, a
    forward and a backward GRU read that sequence, and a fully connected
    layer scores every sequence position over ``num_classes + 1`` labels
    (the extra label, index ``num_classes``, is the CTC blank).

    Attributes created: ``image``, ``output``, ``log_probs`` and, when not
    in inference mode, ``label``, ``cost`` and ``eval``.
    '''

    def __init__(self, num_classes, shape, is_infer=False):
        '''
        :param num_classes: The size of the character dict.
        :type num_classes: int
        :param shape: The size of the input images as (height, width).
        :type shape: tuple of 2 int
        :param is_infer: For inference or not. When True, the label input,
                         the CTC cost and the evaluator are not created.
        :type is_infer: bool
        '''
        self.num_classes = num_classes
        self.shape = shape
        self.is_infer = is_infer
        # Images are fed as flattened vectors of height * width values.
        self.image_vector_size = shape[0] * shape[1]

        self.__declare_input_layers__()
        self.__build_nn__()

    def __declare_input_layers__(self):
        '''
        Define the input layers: the image and, when training, the label.
        '''
        # Image input as a flattened float vector, with its height and
        # width attached.
        self.image = layer.data(
            name='image',
            type=paddle.data_type.dense_vector(self.image_vector_size),
            height=self.shape[0],
            width=self.shape[1])

        # Label input as a sequence of character IDs (training only).
        if not self.is_infer:
            self.label = layer.data(
                name='label',
                type=paddle.data_type.integer_value_sequence(self.num_classes))

    def __build_nn__(self):
        '''
        Build the network topology.

        Sets ``self.output`` (unnormalized per-position scores over
        ``num_classes + 1`` labels, the last one being the CTC blank) and
        ``self.log_probs``. When not in inference mode it also sets
        ``self.cost`` (warp-CTC loss) and ``self.eval`` (CTC error
        evaluator).
        '''
        # Get the image features with CNN.
        conv_features = self.conv_groups(self.image, conf.filter_num,
                                         conf.with_bn)

        # Expand the output of CNN into a sequence of feature vectors.
        sliced_feature = layer.block_expand(
            input=conv_features,
            num_channels=conf.num_channels,
            stride_x=conf.stride_x,
            stride_y=conf.stride_y,
            block_x=conf.block_x,
            block_y=conf.block_y)

        # Use RNN to capture sequence information forwards and backwards.
        gru_forward = simple_gru(
            input=sliced_feature, size=conf.hidden_size, act=Relu())
        gru_backward = simple_gru(
            input=sliced_feature,
            size=conf.hidden_size,
            act=Relu(),
            reverse=True)

        # Map the output of RNN to character distribution. The activation is
        # linear because warp-CTC below consumes the raw scores.
        self.output = layer.fc(
            input=[gru_forward, gru_backward],
            size=self.num_classes + 1,
            act=Linear())

        # NOTE: despite its name, this applies a softmax, so it holds
        # probabilities rather than log-probabilities.
        self.log_probs = layer.mixed(
            input=layer.identity_projection(input=self.output),
            act=paddle.activation.Softmax())

        # Use warp CTC to calculate cost for a CTC task. The blank label is
        # the extra last class, index num_classes.
        if not self.is_infer:
            self.cost = layer.warp_ctc(
                input=self.output,
                label=self.label,
                size=self.num_classes + 1,
                norm_by_times=conf.norm_by_times,
                blank=self.num_classes)

            self.eval = evaluator.ctc_error(input=self.output, label=self.label)

    def conv_groups(self, input, num, with_bn):
        '''
        Get the image features with a stack of image convolution groups.

        One ``img_conv_group`` is built per entry of ``conf.filter_num_list``.
        Each group holds ``num // 4`` convolution layers using that entry's
        filter count, followed by pooling.

        :param input: Input layer.
        :type input: LayerOutput
        :param num: Number of the filters; must be a multiple of 4. It only
                    decides how many convolution layers each group contains
                    (``num // 4``). The filter count of every layer comes
                    from ``conf.filter_num_list``.
        :type num: int
        :param with_bn: Whether with batch normalization.
        :type with_bn: bool
        :return: Output of the last convolution group (``input`` itself when
                 ``conf.filter_num_list`` is empty).
        :rtype: LayerOutput
        '''
        assert num % 4 == 0, 'num must be a multiple of 4'

        # Floor division: `num / 4` is a float under Python 3 (or with
        # `from __future__ import division`), and repeating a list a float
        # number of times raises TypeError.
        convs_per_group = num // 4

        tmp = input
        for idx, num_filter in enumerate(conf.filter_num_list):
            tmp = img_conv_group(
                input=tmp,
                # Only the first group is told its input has one channel;
                # later groups pass None (presumably letting Paddle take the
                # channel count from the previous layer — confirm).
                num_channels=1 if idx == 0 else None,
                conv_padding=conf.conv_padding,
                conv_num_filter=[num_filter] * convs_per_group,
                conv_filter_size=conf.conv_filter_size,
                conv_act=Relu(),
                conv_with_batchnorm=with_bn,
                pool_size=conf.pool_size,
                pool_stride=conf.pool_stride)

        return tmp