classifier.py 3.0 KB
Newer Older
Y
Yibing Liu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model for classifier."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid

from model.bert import BertModel


25 26 27
def create_model(args, bert_config, num_labels, is_prediction=False):
    """Build the inputs, BERT encoder and classification head.

    Args:
        args: Run configuration. This function reads ``args.max_seq_len``
            (second dimension of the ``input_mask`` shape) and
            ``args.use_fp16`` (forwarded to ``BertModel``).
        bert_config: Configuration object passed unchanged to ``BertModel``.
        num_labels: Number of target classes; used as the output size of the
            classification layer.
        is_prediction: If true, return only what is needed for inference and
            skip the loss/accuracy part of the graph.

    Returns:
        When ``is_prediction`` is true, the tuple
        ``(data_loader, probs, feed_targets_name)`` where
        ``feed_targets_name`` holds the names of the four non-label inputs.
        Otherwise the tuple
        ``(data_loader, loss, probs, accuracy, num_seqs)``.
    """
    # One entry per feed variable: (name, shape, dtype, lod_level).
    # ``None`` leaves that dimension unspecified in the declaration.
    input_specs = [
        ('src_ids', [None, None], 'int64', 0),
        ('pos_ids', [None, None], 'int64', 0),
        ('sent_ids', [None, None], 'int64', 0),
        ('input_mask', [None, args.max_seq_len, 1], 'float32', 0),
        ('labels', [None, 1], 'int64', 0),
    ]

    inputs = [
        fluid.data(name=name, shape=shape, dtype=dtype, lod_level=lod_level)
        for name, shape, dtype, lod_level in input_specs
    ]
    (src_ids, pos_ids, sent_ids, input_mask, labels) = inputs

    # The loader feeds all five variables, including ``labels``, regardless
    # of ``is_prediction``.
    data_loader = fluid.io.DataLoader.from_generator(
        feed_list=inputs, capacity=50, iterable=False)

    bert = BertModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        input_mask=input_mask,
        config=bert_config,
        use_fp16=args.use_fp16)

    cls_feats = bert.get_pooled_output()
    cls_feats = fluid.layers.dropout(
        x=cls_feats,
        dropout_prob=0.1,
        dropout_implementation="upscale_in_train")
    # NOTE(review): num_flatten_dims=2 suggests the pooled output is 3-D
    # (presumably batch x 1 x hidden) -- confirm against BertModel.
    logits = fluid.layers.fc(
        input=cls_feats,
        num_flatten_dims=2,
        size=num_labels,
        param_attr=fluid.ParamAttr(
            name="cls_out_w",
            initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
        bias_attr=fluid.ParamAttr(
            name="cls_out_b", initializer=fluid.initializer.Constant(0.)))

    if is_prediction:
        # Inference path: softmax over the un-reshaped logits; labels are
        # deliberately left out of the feed target names.
        probs = fluid.layers.softmax(logits)
        feed_targets_name = [
            src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
        ]
        return data_loader, probs, feed_targets_name

    # Training/evaluation path: flatten logits to [-1, num_labels] before
    # computing the loss against ``labels``.
    logits = fluid.layers.reshape(logits, [-1, num_labels], inplace=True)
    ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
        logits=logits, label=labels, return_softmax=True)
    loss = fluid.layers.mean(x=ce_loss)

    # ``num_seqs`` is handed to the accuracy op as its ``total`` output.
    num_seqs = fluid.layers.create_tensor(dtype='int64')
    accuracy = fluid.layers.accuracy(input=probs, label=labels, total=num_seqs)

    return data_loader, loss, probs, accuracy, num_seqs