Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • ERNIE
  • Issue
  • #532

E
ERNIE
  • 项目概览

PaddlePaddle / ERNIE
大约 2 年 前同步成功

通知 115
Star 5997
Fork 1271
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 29
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 0
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
E
ERNIE
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 29
    • Issue 29
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 0
    • 合并请求 0
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 7月 29, 2020 by saxon_zh@saxon_zhGuest

完型填空的例子能正确运行,导出推理模型不报错,但是推理模型为无效模型。

Created by: hello-web

coding=gbk

import os import re import time import logging import json from random import random from tqdm import tqdm from functools import reduce, partial

import numpy as np import logging import argparse

import paddle import paddle.fluid as F import paddle.fluid.dygraph as FD import paddle.fluid.layers as L

from propeller import log import propeller.paddle as propeller

log.setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)

#from model.bert import BertConfig, BertModelLayer #from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer from ernie.optimization import AdamW, LinearDecay from ernie.modeling_ernie import ErnieModelForPretraining, ErnieModel

if name == 'main': parser = argparse.ArgumentParser('save_inference_model with ERNIE') parser.add_argument('--from_pretrained', type=str, required=True, help='pretrained model directory or tag') parser.add_argument('--inference_model_dir', type=str, default=None, help='inference model output directory')

args = parser.parse_args()

tokenizer = ErnieTokenizer.from_pretrained(args.from_pretrained)
rev_dict = {v: k for k, v in tokenizer.vocab.items()}
rev_dict[tokenizer.pad_id] = '' # replace [PAD]
rev_dict[tokenizer.sep_id] = '' # replace [PAD]
rev_dict[tokenizer.unk_id] = '' # replace [PAD]

@np.vectorize
def rev_lookup(i):
    return rev_dict[i]

place = F.CPUPlace()
with FD.guard(place):
    model = ErnieModelForPretraining.from_pretrained(args.from_pretrained)

    src, _ = tokenizer.encode('戊戌变法,又称百日维新,是 [MASK] [MASK] [MASK] 、梁启超等维新派人士通过光绪帝进行 的一场资产阶级改良。')
    print(src)
    src_ids = np.expand_dims(src, 0)
    src_ids = FD.to_variable(src_ids)

    if args.inference_model_dir is not None:
        log.debug('saving inference model')
        class InferemceModel(ErnieModelForPretraining):
            def __init__(self, *args, **kwargs):
                super(InferemceModel, self).__init__(*args, **kwargs)
                #del self.pooler_heads
            def forward(self, src_ids, *args, **kwargs):
                #mlm_pos = kwargs.pop('mlm_pos', np.int64(3))
                pooled, encoded = ErnieModel.forward(self, src_ids, *args, **kwargs)
                encoded_2d = L.gather_nd(encoded, L.where(src_ids == 3))
                encoded_2d = self.mlm(encoded_2d)
                encoded_2d = self.mlm_ln(encoded_2d)
                logits_2d = L.matmul(encoded_2d, self.word_emb.weight, transpose_y=True) + self.mlm_bias
                return logits_2d

        model.__class__ = InferemceModel 
        logits = model(src_ids)
        _, static_model = FD.TracedLayer.trace(model, inputs=[src_ids])
        static_model.save_inference_model(args.inference_model_dir)
        
        logits = logits.numpy()
        output_ids = np.argmax(logits, -1)
        seg_txt = rev_lookup(output_ids)
        print(''.join(seg_txt))

以上为复现代码 问题:完型填空的例子能正确运行,导出推理模型不报错,但是推理模型为无效模型,不能进行正确推理。 导出的模型文件参考文件名: model create_parameter_0.b_0 embedding_[0-2].w_0 layer_norm_[0-25].b_0 layer_norm_[0-25].w_0 linear_[0-74].b_0 linear_[0-74].w_0

指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/ERNIE#532
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7