#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""extract embeddings from ERNIE encoder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import argparse
import numpy as np
import multiprocessing

import logging
import paddle.fluid as fluid

import reader.task_reader as task_reader
from model.ernie_v1 import ErnieConfig, ErnieModel
from utils.args import ArgumentGroup, print_arguments, prepare_logger
from utils.init import init_pretraining_params

log = logging.getLogger()

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("ernie_config_path",         str,  None, "Path to the json file for ernie model config.")
model_g.add_arg("init_pretraining_params",   str,  None,
                "Init pre-training params which preforms fine-tuning from. If the "
                 "arg 'init_checkpoint' has been set, this argument wouldn't be valid.")
model_g.add_arg("output_dir",                str,  "embeddings", "path to save embeddings extracted by ernie_encoder.")

data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_set",            str,  None,  "Path to data for calculating ernie_embeddings.")
data_g.add_arg("vocab_path",          str,  None,  "Vocabulary path.")
data_g.add_arg("max_seq_len",         int,  512,   "Number of words of the longest seqence.")
data_g.add_arg("batch_size",          int,  32,    "Total examples' number in batch for training.")
data_g.add_arg("do_lower_case",       bool, True,
               "Whether to lower case the input text. Should be True for uncased models and False for cased models.")

run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda",                     bool,   True,  "If set, use GPU for training.")
# yapf: enable
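# Example invocation (all paths below are placeholders; substitute your own files):
#   python ernie_encoder.py \
#       --use_cuda true \
#       --ernie_config_path config/ernie_config.json \
#       --vocab_path config/vocab.txt \
#       --init_pretraining_params params/ \
#       --data_set data/input.tsv \
#       --max_seq_len 128 \
#       --batch_size 32 \
#       --output_dir embeddings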


def create_model(args, pyreader_name, ernie_config):
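    """Build the DataLoader and the ERNIE forward graph used for embedding extraction.

    Returns the DataLoader together with a dict holding the pooled [CLS]
    embedding and the unpadded top-layer token embeddings. `pyreader_name`
    is accepted but currently unused.
    """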
    src_ids = fluid.layers.data(name='1', shape=[-1, args.max_seq_len, 1], dtype='int64')
    sent_ids = fluid.layers.data(name='2', shape=[-1, args.max_seq_len, 1], dtype='int64')
    pos_ids = fluid.layers.data(name='3', shape=[-1, args.max_seq_len, 1], dtype='int64')
    task_ids = fluid.layers.data(name='4', shape=[-1, args.max_seq_len, 1], dtype='int64')
    input_mask = fluid.layers.data(name='5', shape=[-1, args.max_seq_len, 1], dtype='float32')
    seq_lens = fluid.layers.data(name='8', shape=[-1], dtype='int64')
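    # these inputs are fed, in this order, by ExtractEmbeddingReader.data_generator
    # through the DataLoader created below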

    pyreader = fluid.io.DataLoader.from_generator(
            feed_list=[src_ids, sent_ids, pos_ids, task_ids, input_mask, seq_lens],
            capacity=70,
            iterable=False)

    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        input_mask=input_mask,
        config=ernie_config)

    enc_out = ernie.get_sequence_output()
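    # the sequence output is padded; keep only the first seq_len tokens of each example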
    unpad_enc_out = fluid.layers.sequence_unpad(enc_out, length=seq_lens)
    cls_feats = ernie.get_pooled_output()

    # set persistable = True so these outputs are not freed by memory optimization
    enc_out.persistable = True
    unpad_enc_out.persistable = True
    cls_feats.persistable = True

    graph_vars = {
        "cls_embeddings": cls_feats,
        "top_layer_embeddings": unpad_enc_out,
    }

    return pyreader, graph_vars


def main(args):
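    """Build the inference program, load the pre-trained parameters and save embeddings to disk."""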
    ernie_config = ErnieConfig(args.ernie_config_path)
    ernie_config.print_config()

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

    exe = fluid.Executor(place)

    reader = task_reader.ExtractEmbeddingReader(
        vocab_path=args.vocab_path,
        max_seq_len=args.max_seq_len,
        do_lower_case=args.do_lower_case)

    startup_prog = fluid.Program()

    data_generator = reader.data_generator(
        input_file=args.data_set,
        batch_size=args.batch_size,
        epoch=1,
        shuffle=False)

    total_examples = reader.get_num_examples(args.data_set)

    print("Device count: %d" % dev_count)
    print("Total num examples: %d" % total_examples)

    infer_program = fluid.Program()

    with fluid.program_guard(infer_program, startup_prog):
        with fluid.unique_name.guard():
            pyreader, graph_vars = create_model(
                args, pyreader_name='reader', ernie_config=ernie_config)

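    # clone with for_test=True to get an inference-only program (e.g. dropout disabled)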
    infer_program = infer_program.clone(for_test=True)

    exe.run(startup_prog)

    if args.init_pretraining_params:
        init_pretraining_params(
            exe, args.init_pretraining_params, main_program=startup_prog)
    else:
        raise ValueError(
            "args 'init_pretraining_params' must be specified")

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_threads = dev_count

    pyreader.set_batch_generator(data_generator)
    pyreader.start()

    total_cls_emb = []
    total_top_layer_emb = []
    total_labels = []
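    # run the inference program batch by batch until the DataLoader is exhausted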
    while True:
        try:
            cls_emb, unpad_top_layer_emb = exe.run(
                program=infer_program,
                fetch_list=[
                    graph_vars["cls_embeddings"].name,
                    graph_vars["top_layer_embeddings"].name
                ],
                return_numpy=False)
            # cls_emb: [batch_size, hidden_size]; unpad_top_layer_emb: [total tokens in batch, hidden_size]
            total_cls_emb.append(np.array(cls_emb))
            total_top_layer_emb.append(np.array(unpad_top_layer_emb))
        except fluid.core.EOFException:
            break

    total_cls_emb = np.concatenate(total_cls_emb)
    total_top_layer_emb = np.concatenate(total_top_layer_emb)
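    # after concatenation: total_cls_emb has one row per example,
    # total_top_layer_emb has one row per (unpadded) token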

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    else:
        raise RuntimeError('output dir exists: %s' % args.output_dir)

    with open(os.path.join(args.output_dir, "cls_emb.npy"),
              "wb") as cls_emb_file:
        np.save(cls_emb_file, total_cls_emb)
    with open(os.path.join(args.output_dir, "top_layer_emb.npy"),
              "wb") as top_layer_emb_file:
        np.save(top_layer_emb_file, total_top_layer_emb)


if __name__ == '__main__':
    prepare_logger(log)
    args = parser.parse_args()
    print_arguments(args)

    main(args)