# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The function lex_net(args) define the lexical analysis network structure
"""
import sys
import os
import math

import paddle
import paddle.fluid as fluid

from reader import Dataset
sys.path.append("..")  # make the shared models/ and preprocess/ packages importable
from models.sequence_labeling import nets
from models.representation.ernie import ernie_encoder
from preprocess.ernie import task_reader

def create_model(args, vocab_size, num_labels, mode='train'):
    """Create the LAC network together with its metrics for train/test/infer."""

    # model input: variable-length token id sequences, hence lod_level=1
    words = fluid.layers.data(
        name='words', shape=[-1, 1], dtype='int64', lod_level=1)
    targets = fluid.layers.data(
        name='targets', shape=[-1, 1], dtype='int64', lod_level=1)

    # for the inference process
    if mode == 'infer':
        crf_decode = nets.lex_net(
            words, args, vocab_size, num_labels, for_infer=True, target=None)
        return {"feed_list": [words], "words": words, "crf_decode": crf_decode}

    # for the train or test process
    avg_cost, crf_decode = nets.lex_net(
        words, args, vocab_size, num_labels, for_infer=False, target=targets)

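    # chunk_eval turns decoded tag sequences into chunk-level counts; with the
    # IOB scheme each chunk type owns a B-/I- tag pair plus a shared O tag, so
    # the number of chunk types is (num_labels - 1) / 2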
    (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
     num_correct_chunks) = fluid.layers.chunk_eval(
        input=crf_decode,
        label=targets,
        chunk_scheme="IOB",
        num_chunk_types=int(math.ceil((num_labels - 1) / 2.0)))
    chunk_evaluator = fluid.metrics.ChunkEvaluator()
    chunk_evaluator.reset()

    ret = {
        "feed_list": [words, targets],
        "words": words,
        "targets": targets,
        "avg_cost": avg_cost,
        "crf_decode": crf_decode,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "chunk_evaluator": chunk_evaluator,
        "num_infer_chunks": num_infer_chunks,
        "num_label_chunks": num_label_chunks,
        "num_correct_chunks": num_correct_chunks
    }
    return ret


def create_pyreader(args, file_name, feed_list, place, model='lac',
                    reader=None, return_reader=False, mode='train'):
    """Create a PyReader that feeds batches to either the LAC or ERNIE model."""
    # init reader
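    # the PyReader prefetches up to `capacity` batches on a background thread;
    # iterable=True lets the caller drive it with a plain for-loop instead of
    # calling pyreader.start()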
    pyreader = fluid.io.PyReader(
        feed_list=feed_list,
        capacity=300,
        use_double_buffer=True,
        iterable=True
    )
    if model == 'lac':
        if reader is None:
            reader = Dataset(args)
        # create lac pyreader
        if mode == 'train':
            pyreader.decorate_sample_list_generator(
                paddle.batch(
                    paddle.reader.shuffle(
                        reader.file_reader(file_name),
                        buf_size=args.traindata_shuffle_buffer
                    ),
                    batch_size=args.batch_size
                ),
                places=place
            )
        else:
            pyreader.decorate_sample_list_generator(
                paddle.batch(
                    reader.file_reader(file_name, mode=mode),
                    batch_size=args.batch_size
                ),
                places=place
            )

    elif model == 'ernie':
        # create ernie pyreader
        if reader is None:
            reader = task_reader.SequenceLabelReader(
                vocab_path=args.vocab_path,
                label_map_config=args.label_map_config,
                max_seq_len=args.max_seq_len,
                do_lower_case=args.do_lower_case,
                in_tokens=False,
                random_seed=args.random_seed)

        if mode == 'train':
            pyreader.decorate_batch_generator(
                reader.data_generator(
                    file_name, args.batch_size, args.epoch, shuffle=True, phase="train"
                ),
                places=place
            )
        else:
            pyreader.decorate_batch_generator(
                reader.data_generator(
                    file_name, args.batch_size, epoch=1, shuffle=False, phase=mode
                ),
                places=place
            )

    if return_reader:
        return pyreader, reader
    else:
        return pyreader
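
# A minimal wiring sketch for the LAC branch (a sketch only: `vocab_size`,
# `num_labels` and `train_file` are hypothetical caller-supplied values, and
# `args` is assumed to carry the fields referenced above, e.g. batch_size and
# traindata_shuffle_buffer):
#
#   main_prog, startup_prog = fluid.Program(), fluid.Program()
#   with fluid.program_guard(main_prog, startup_prog):
#       train_ret = create_model(args, vocab_size, num_labels, mode='train')
#       pyreader = create_pyreader(args, train_file, train_ret["feed_list"],
#                                  fluid.CPUPlace(), model='lac', mode='train')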

def create_ernie_model(args, ernie_config):
    """
    Create the LAC model based on the ERNIE encoder.
    """
    # ERNIE's input data
    src_ids = fluid.layers.data(
        name='src_ids', shape=[args.max_seq_len, 1], dtype='int64', lod_level=0)
    sent_ids = fluid.layers.data(
        name='sent_ids', shape=[args.max_seq_len, 1], dtype='int64', lod_level=0)
    pos_ids = fluid.layers.data(
        name='pos_ids', shape=[args.max_seq_len, 1], dtype='int64', lod_level=0)
    input_mask = fluid.layers.data(
        name='input_mask', shape=[args.max_seq_len, 1], dtype='float32',
        lod_level=0)
    padded_labels = fluid.layers.data(
        name='padded_labels', shape=[args.max_seq_len, 1], dtype='int64',
        lod_level=0)
    seq_lens = fluid.layers.data(
        name='seq_lens', shape=[-1], dtype='int64', lod_level=0)
    squeeze_labels = fluid.layers.squeeze(padded_labels, axes=[-1])
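    # squeeze drops the trailing singleton dimension so the labels match the
    # [batch, max_seq_len] layout that chunk_eval expects below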

    ernie_inputs = {
        "src_ids": src_ids,
        "sent_ids": sent_ids,
        "pos_ids": pos_ids,
        "input_mask": input_mask,
        "seq_lens": seq_lens
    }
    embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)

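    # ernie_encoder runs the pretrained transformer stack; the sequence
    # labeling head below only needs its padded token-level embeddings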
    padded_token_embeddings = embeddings["padded_token_embeddings"]

    emission = fluid.layers.fc(
        size=args.num_labels,
        input=padded_token_embeddings,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Uniform(
                low=-args.init_bound, high=args.init_bound),
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=1e-4)),
        num_flatten_dims=2)

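    # the fc layer above gives per-token emission scores; the linear-chain CRF
    # adds label-transition scores, and the same 'crfw' parameters are shared
    # between the training cost and Viterbi decoding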
    crf_cost = fluid.layers.linear_chain_crf(
        input=emission,
        label=padded_labels,
        param_attr=fluid.ParamAttr(
            name='crfw',
            learning_rate=args.crf_learning_rate),
        length=seq_lens)
    avg_cost = fluid.layers.mean(x=crf_cost)
    crf_decode = fluid.layers.crf_decoding(
        input=emission, param_attr=fluid.ParamAttr(name='crfw'),
        length=seq_lens)

    (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
     num_correct_chunks) = fluid.layers.chunk_eval(
        input=crf_decode,
        label=squeeze_labels,
        chunk_scheme="IOB",
        num_chunk_types=int(math.ceil((args.num_labels - 1) / 2.0)),
        seq_length=seq_lens)
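    # ChunkEvaluator accumulates the per-batch chunk counts so that precision,
    # recall and F1 can be aggregated over a whole evaluation pass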
    chunk_evaluator = fluid.metrics.ChunkEvaluator()
    chunk_evaluator.reset()

    ret = {
        "feed_list": [src_ids, sent_ids, pos_ids, input_mask,
                      padded_labels, seq_lens],
        "words": src_ids,
        "labels": padded_labels,
        "avg_cost": avg_cost,
        "crf_decode": crf_decode,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "chunk_evaluator": chunk_evaluator,
        "num_infer_chunks": num_infer_chunks,
        "num_label_chunks": num_label_chunks,
        "num_correct_chunks": num_correct_chunks
    }

    return ret
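
# A minimal wiring sketch for the ERNIE branch (a sketch only: it assumes
# `ernie_config` was loaded from the pretrained model's config JSON, that
# `train_file` is a caller-supplied path, and that `args` carries
# max_seq_len, num_labels, init_bound and crf_learning_rate):
#
#   with fluid.program_guard(fluid.Program(), fluid.Program()):
#       ret = create_ernie_model(args, ernie_config)
#       pyreader = create_pyreader(args, train_file, ret["feed_list"],
#                                  fluid.CPUPlace(), model='ernie',
#                                  mode='train')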