Unverified commit b0c85c42, authored by S stevezhangz, committed by GitHub

Add files via upload

Parent: c3ebf570
[hyper_parameters]
maxlen = 150
batch_size = 6
max_pred = 5
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4
d_k = 64
d_v = 64
n_segments = 2
seed = 9527
lr = 0.0001
epc = 250
[file_dir]
data = ./data
weight = ./checkpoint
[device]
force = cuda:0
tensor_dtype = torch.cuda.FloatTensor
confirm_gpu = 1
# -*- coding: utf-8 -*-
# code by Steve Zhang Z
# Time: 4/22/2021
# email: stevezhangz@163.com
import configparser
conf=configparser.ConfigParser()
conf.read("Config.cfg")
maxlen = int(conf.get("hyper_parameters","maxlen"))
batch_size = int(conf.get("hyper_parameters","batch_size"))
max_pred = int(conf.get("hyper_parameters","max_pred"))
n_layers = int(conf.get("hyper_parameters","n_layers"))
n_heads = int(conf.get("hyper_parameters","n_heads"))
d_model = int(conf.get("hyper_parameters","d_model"))
d_ff =eval(conf.get("hyper_parameters","d_ff"))
d_k = int(conf.get("hyper_parameters","d_k"))
d_v=int(conf.get("hyper_parameters","d_v"))
n_segments = int(conf.get("hyper_parameters","n_segments"))
random_seed=int(conf.get("hyper_parameters","seed"))
lr=float(conf.get("hyper_parameters","lr"))
epoches=int(conf.get("hyper_parameters","epc"))
data_dir = conf.get("file_dir", "data")
weight_dir = conf.get("file_dir", "weight")
device=conf.get("device","force")
default_tensor_type=conf.get("device","tensor_dtype")
use_gpu=int(conf.get("device","confirm_gpu"))
# -*- coding: utf-8 -*-
# code by Steve Zhang Z
# Time: 4/22/2021
# email: stevezhangz@163.com
import torch
from torch import nn
import numpy as np
import os
import argparse
import math
class Grelu(nn.Module):
def __init__(self):
super(Grelu, self).__init__()
def forward(self,x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
def grelu(x):
    # Gaussian Error Linear Units (GELUs), tanh approximation
# url: https://arxiv.org/pdf/1606.08415.pdf
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
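# Quick numerical check (illustrative only, not used by the model): the tanh
# approximation above should agree with PyTorch's exact torch.nn.functional.gelu
# to within a few 1e-3 over a typical activation range.
def _check_gelu_approximation():
    x = torch.linspace(-3.0, 3.0, steps=13)
    return torch.max(torch.abs(grelu(x) - torch.nn.functional.gelu(x)))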
class Embedding(nn.Module):
def __init__(self,
vocab_size,
emb_size,
max_len,
seg_size):
super(Embedding, self).__init__()
self.emb_x=nn.Embedding(vocab_size,emb_size)
self.emb_pos=nn.Embedding(max_len,emb_size)
self.emb_seg=nn.Embedding(seg_size,emb_size)
        # LayerNorm over the summed embeddings helps stabilise training and speeds up convergence
        self.norm = nn.LayerNorm(emb_size)
    def forward(self, x, seg):
        length = x.size(1)
        # position ids must live on the same device as the input ids
        pos = torch.arange(length, device=x.device)
        return self.norm(self.emb_x(x) + self.emb_pos(pos) + self.emb_seg(seg))
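# Shape sketch (illustrative only, with toy sizes): token, position and segment
# embeddings are summed into one (batch, seq_len, emb_size) tensor before the encoder.
def _check_embedding_shapes(batch=2, seq_len=8, vocab_size=100, emb_size=32):
    emb = Embedding(vocab_size, emb_size, max_len=16, seg_size=2)
    x = torch.randint(0, vocab_size, (batch, seq_len))
    seg = torch.zeros(batch, seq_len, dtype=torch.long)
    return emb(x, seg).shape  # expected: torch.Size([batch, seq_len, emb_size])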
class multi_head(nn.Module):
def __init__(self,
emb_size,
dk,
dv,
n_head,
):
super(multi_head, self).__init__()
self.Q=nn.Linear(emb_size,n_head*dk)
self.K=nn.Linear(emb_size,n_head*dk)
self.V=nn.Linear(emb_size,n_head*dv)
self.layer_norm=nn.LayerNorm(emb_size)
self.Linear=nn.Linear(n_head*dk,emb_size)
self.n_head=n_head
self.dk=dk
self.dv=dv
    def dot_product_with_mask(self, query, key, value, mask):
        # scaled dot-product attention; padding positions get a large negative score
        scores = torch.matmul(query, key.transpose(-1, -2)) / np.sqrt(self.dk)
        scores = scores.masked_fill_(mask, -1e9)
        # size: batch_size, n_head, length, length
        return torch.matmul(nn.Softmax(dim=-1)(scores), value)
def forward(self,Input,mask):
residual=Input
batch_size=Input.size()[0]
# Size: batch_size, n_head,seq_length,dk(or dv)
        K = self.K(Input).view(batch_size, -1, self.n_head, self.dk).transpose(1, 2)
        Q = self.Q(Input).view(batch_size, -1, self.n_head, self.dk).transpose(1, 2)
        V = self.V(Input).view(batch_size, -1, self.n_head, self.dv).transpose(1, 2)
        # broadcast the (batch_size, length, length) padding mask to every attention head
        mask = mask.unsqueeze(1).repeat(1, self.n_head, 1, 1)
        context = self.dot_product_with_mask(query=Q,
                                             key=K,
                                             value=V,
                                             mask=mask)
context=context.transpose(1,2).contiguous().view(batch_size,-1,self.n_head*self.dk)
        # context shape: (batch_size, length, n_head * dk)
output=self.Linear(context)
# finally return batch_size,length,emb_size
return self.layer_norm(output+residual)
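# Shape sketch (illustrative only, toy sizes instead of the Config.cfg values): a
# (batch, seq_len, emb_size) input plus a boolean (batch, seq_len, seq_len) padding
# mask comes back as a (batch, seq_len, emb_size) attention output.
def _check_multi_head_shapes(batch=2, seq_len=8, emb_size=32, n_head=4):
    layer = multi_head(emb_size, dk=8, dv=8, n_head=n_head)
    x = torch.randn(batch, seq_len, emb_size)
    mask = torch.zeros(batch, seq_len, seq_len, dtype=torch.bool)  # nothing is padding
    return layer(x, mask).shape  # expected: torch.Size([batch, seq_len, emb_size])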
class basic_block(nn.Module):
    def __init__(self, emb_size,
                 dff,
                 dk,
                 dv,
                 n_head):
        super(basic_block, self).__init__()
        # position-wise feed-forward network with a GELU between the two projections
        self.feed_forward = nn.Sequential(
            nn.Linear(emb_size, dff),
            Grelu(),
            nn.Linear(dff, emb_size)
        )
        self.multi_head = multi_head(emb_size, dk, dv, n_head)
    def forward(self, Input, mask):
        return self.feed_forward(self.multi_head(Input, mask))
class Bert(nn.Module):
def __init__(self,
n_layers,
vocab_size,
emb_size,
max_len,
seg_size,
dff,
dk,
dv,
n_head,
n_class):
super(Bert, self).__init__()
self.vocab_size=vocab_size
self.emb_size=emb_size
self.emb_layer=Embedding(vocab_size,emb_size,max_len,seg_size)
self.encoder_layer=nn.Sequential(*[basic_block(emb_size,dff,dk,dv,n_head) for i in range(n_layers)])
self.fc1=nn.Sequential(
nn.Linear(emb_size, vocab_size),
nn.Dropout(0.5),
nn.Tanh(),
nn.Linear(vocab_size, n_class)
)
fc2=nn.Linear(emb_size, vocab_size)
fc2.weight=self.emb_layer.emb_x.weight
self.fc2=nn.Sequential(
nn.Linear(emb_size, emb_size),
Grelu(),
fc2
)
def get_mask(self,In):
batch_size,length,mask=In.size()[0],In.size()[1],In
mask=mask.eq(0).unsqueeze(1)
return mask.data.expand(batch_size,length,length)
def forward(self,x,seg,mask_):
mask=self.get_mask(x)
output=self.emb_layer(x=x,seg=seg)
for layer in self.encoder_layer:
output=layer(output,mask)
cls=self.fc1(output[:,0])
masked_pos = mask_[:, :, None].expand(-1, -1,self.emb_size)
masked=torch.gather(output,1,masked_pos)
logits=self.fc2(masked)
return logits,cls
def Train(self,epoches,criterion,optimizer,train_data_loader,use_gpu,device,
eval_data_loader=None,save_dir="./checkpoint",load_dir=None,save_freq=5,
):
import tqdm
        if load_dir is not None:
if os.path.exists(load_dir):
checkpoint=torch.load(load_dir)
try:
self.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception:
                    print("failed to load the state_dict")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for epc in range(epoches):
tq=tqdm.tqdm(train_data_loader)
for seq,(input_ids, segment_ids, masked_tokens, masked_pos, isNext) in enumerate(tq):
if use_gpu:
input_ids, segment_ids, masked_tokens, masked_pos, isNext=input_ids.to(device), \
segment_ids.to(device), \
masked_tokens.to(device), \
masked_pos.to(device),\
isNext.to(device)
logits_lm, logits_clsf = self(x=input_ids, seg=segment_ids, mask_=masked_pos)
loss_word = criterion(logits_lm.view(-1, self.vocab_size), masked_tokens.view(-1)) # for masked LM
loss_word = (loss_word.float()).mean()
loss_cls = criterion(logits_clsf, isNext) # for sentence classification
loss = loss_word + loss_cls
optimizer.zero_grad()
loss.backward()
optimizer.step()
tq.set_description(f"train Epoch {epc+1}, Batch{seq}")
tq.set_postfix(train_loss=loss)
            if eval_data_loader is not None:
tq=tqdm.tqdm(eval_data_loader)
with torch.no_grad():
for seq,(input_ids, segment_ids, masked_tokens, masked_pos, isNext) in enumerate(tq):
input_ids, segment_ids, masked_tokens, masked_pos, isNext = input_ids.to(device), \
segment_ids.to(device), \
masked_tokens.to(device), \
masked_pos.to(device), \
isNext.to(device)
logits_lm, logits_clsf = self(x=input_ids, seg=segment_ids, mask_=masked_pos)
loss_word = criterion(logits_lm.view(-1, self.vocab_size),
masked_tokens.view(-1)) # for masked LM
loss_word = (loss_word.float()).mean()
loss_cls = criterion(logits_clsf, isNext) # for sentence classification
loss = loss_word + loss_cls
tq.set_description(f"eval Epoch {epc + 1}, Batch{seq}")
tq.set_postfix(train_loss=loss)
if (epc+1)%save_freq==0:
checkpoint = {'epoch': epc,
                              'loss': loss.item(),
'model': self.state_dict(),
'optimizer': optimizer.state_dict()
}
torch.save(checkpoint, save_dir+ f"/checkpoint_{epc}.pth")
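# Minimal forward-pass smoke test (illustrative only; the real hyper-parameters come
# from Config.cfg via Config_load.py). Run this file directly to check the shapes.
if __name__ == "__main__":
    toy = Bert(n_layers=2, vocab_size=100, emb_size=32, max_len=20, seg_size=2,
               dff=64, dk=8, dv=8, n_head=4, n_class=2)
    ids = torch.randint(4, 100, (2, 12))           # fake token ids, no padding
    seg = torch.zeros(2, 12, dtype=torch.long)     # single-segment input
    masked_pos = torch.tensor([[1, 3, 5], [2, 4, 6]])
    logits_lm, logits_cls = toy(x=ids, seg=seg, mask_=masked_pos)
    print(logits_lm.shape, logits_cls.shape)       # (2, 3, 100) and (2, 2)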
import configparser
# param
maxlen = "150"
batch_size = "6"
max_pred = "5"
n_layers = "6"
n_heads = "12"
d_model = "768"
d_ff = "768*4"
d_k = d_v = "64"
n_segments = "2"
random_seed="9527"
learning_rate="0.0001"
epoches="250"
#dir
data_dir="./data"
weight_dir="./checkpoint"
Config_file="Config.cfg"
# device
device="cuda:0"
default_tensor_type="torch.cuda.FloatTensor"
use_gpu="1"
# generate configs
conf=configparser.ConfigParser()
cfg=open(Config_file,"w")
conf.add_section("hyper_parameters")
conf.set("hyper_parameters","maxlen",maxlen)
conf.set("hyper_parameters","batch_size",batch_size)
conf.set("hyper_parameters","max_pred",max_pred)
conf.set("hyper_parameters","n_layers",n_layers)
conf.set("hyper_parameters","n_heads",n_heads)
conf.set("hyper_parameters","d_model",d_model)
conf.set("hyper_parameters","d_ff",d_ff)
conf.set("hyper_parameters","d_k",d_k)
conf.set("hyper_parameters","d_v",d_v )
conf.set("hyper_parameters","n_segments",n_segments)
conf.set("hyper_parameters","seed",random_seed)
conf.set("hyper_parameters","lr",learning_rate)
conf.set("hyper_parameters","epc",epoches)
conf.add_section("file_dir")
conf.set("file_dir","data",data_dir)
conf.set("file_dir","data",weight_dir)
conf.add_section("device")
conf.set("device","force",device)
conf.set("device","tensor_dtype",default_tensor_type)
conf.set("device","confirm_gpu",use_gpu)
conf.write(cfg)
cfg.close()
The MIT License (MIT)
Copyright (c) 2016 JackeyGao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
<p align="center">
<a href="https://github.com/chinese-poetry/chinese-poetry">
<img src="https://avatars3.githubusercontent.com/u/30764933?s=200&v=4" alt="chinese-poetry">
</a>
</p>
<p align="center">阿里招p6/p7 Python Golang | gaojunqi@outlook.com | 上海张江 </p>
<h2 align="center">chinese-poetry: 最全中文诗歌古典文集数据库</h2>
<p align="center">
<a href="https://travis-ci.com/chinese-poetry/chinese-poetry" rel="nofollow">
<img height="28px" alt="Build Status" src="https://img.shields.io/travis/chinese-poetry/chinese-poetry?style=for-the-badge" style="max-width:100%;">
</a>
<a href="https://github.com/chinese-poetry/chinese-poetry/blob/master/LICENSE">
<img height="28px" alt="License" src="http://img.shields.io/badge/license-mit-blue.svg?style=for-the-badge" style="max-width:100%;">
</a>
<a href="https://github.com/chinese-poetry/chinese-poetry/graphs/contributors">
<img height="28px" alt="Contributors" src="https://img.shields.io/github/contributors/chinese-poetry/chinese-poetry.svg?style=for-the-badge" style="max-width:100%;">
</a>
<a href="https://www.patreon.com/jackeygao" rel="nofollow">
<img height="28px" alt="Patreon" src="https://img.shields.io/endpoint.svg?url=https%3A%2F%2Fshieldsio-patreon.herokuapp.com%2Fjackeygao%2Fpledges&amp;style=for-the-badge" style="max-width:100%;">
</a>
<a href="https://github.com/chinese-poetry/chinese-poetry/" rel="nofollow">
<img alt="HitCount" height="28px" src="http://hits.dwyl.io/chinese-poetry/chinese-poetry.svg" style="max-width:100%;">
</a>
</p>
The most comprehensive database of classical Chinese literature: roughly 55,000 Tang poems, 260,000 Song poems, 21,000 Song ci and other classical collections, covering nearly 14,000 poets of the Tang and Song dynasties and about 1,500 ci writers of the two Song periods. The data was gathered from the internet.
**Why this repository?** Classical poetry is a treasure of the Chinese nation, and indeed of the whole world, and it deserves to be passed on. Printed editions of the classical collections exist, but most people do not own the books, so in a sense these vast collections remain out of reach; an electronic version is easy to copy, which is why this open-source database was created. The data is distributed as JSON, so you can start your own project with very little effort.
The crawling of the poems was not documented: the corpus is large, the target sites impose limits, and the crawl was interrupted for more than a week. The complete Song ci collection was added in 2017; see [how the Song ci corpus was crawled and analysed](https://jackeygao.github.io/r/words/crawl-ci.html).
## High-frequency word analysis
<details open>
<summary><b>Most popular cipai (tune names) in Song ci</b></summary>
<div align="center">
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/ci_rhythmic_topK.png" alt="Most popular cipai in Song ci">
</div>
</details>
<details>
<summary><b>High-frequency words in Song ci</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/ci_words_topK.png" alt="High-frequency words in Song ci" style="max-width:100%;">
</details>
<details>
<summary><b>Song ci authors by number of works</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/ci_author_topK.png" alt="Song ci authors by number of works" style="max-width:100%;">
</details>
<details>
<summary><b>High-frequency words in Tang poetry</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/tang_text_topK.png" alt="High-frequency words in Tang poetry" style="max-width:100%;">
</details>
<details>
<summary><b>Tang poetry authors by number of works</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/tang_author_topK.png" alt="Tang poetry authors by number of works" style="max-width:100%;">
</details>
<details>
<summary><b>High-frequency words in Song poetry</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/song_text_topK.png" alt="High-frequency words in Song poetry" style="max-width:100%;">
</details>
<details>
<summary><b>Song poetry authors by number of works</b></summary>
<img src="https://raw.githubusercontent.com/jackeygao/chinese-poetry/master/images/song_author_topK.png" alt="Song poetry authors by number of works" style="max-width:100%;">
</details>
## Datasets
- Complete Tang poetry [json](./json)
- Complete Song poetry [json](./json)
- Complete Song ci [ci](./ci)
- Five Dynasties: Huajianji (Among the Flowers anthology) [wudai/huajianji](./wudai/huajianji)
- Five Dynasties: ci of the two lords of Southern Tang [wudai/nantang](./wudai/nantang)
- The Analects [lunyu](./lunyu)
- The Book of Songs (Shijing) [shijing](./shijing)
- Youmengying [youmengying](./youmengying)
- The Four Books and Five Classics [sishuwujing](./sishuwujing)
- Mengxue (traditional primers) [mengxue](./mengxue)
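All collections are distributed as plain JSON files, so they can be loaded with a few lines of Python. A minimal sketch (the `content` field shown here is the one exposed by the `chuci` collection and consumed by `data_process.py` in this repository; other collections may use different keys):

```python
import json

with open("data/chinese-poetry/chuci/chuci.json", "r", encoding="utf-8") as f:
    poems = json.load(f)

for poem in poems[:3]:
    # each entry stores its verses as a list of strings under "content"
    print("\n".join(poem["content"]))
```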
## Contributing
The goal of this project is to provide well-formatted (JSON) data so that developers can build poetry applications quickly and easily. One person can only do so much, so more maintainers are very welcome. You can contribute in the following ways:
- Submit a PR directly, or open an issue for discussion, to improve and extend the database. In principle any non-religious classical poetic genre is welcome; controversial additions are decided by community vote. When creating a PR that corrects a verse, please cite the source. For more details see the [contribution guidelines](https://github.com/chinese-poetry/chinese-poetry/wiki/%E5%8F%82%E4%B8%8E%E8%B4%A1%E7%8C%AE%E8%A7%84%E8%8C%83).
- If you cannot take part in the work directly, you can support and encourage the maintenance of this database through a recurring [Patreon sponsorship](https://www.patreon.com/jackeygao). If you prefer a one-off donation, you can use [Alipay](https://github.com/jackeyGao/JackeyGao.github.io/blob/master/static/images/alipay.png) or the [WeChat tipping code](https://github.com/jackeyGao/JackeyGao.github.io/blob/master/static/images/wechat.jpg) (please leave your email in the note).
- For suggestions or complaints, feel free to email me at gaojunqi@outlook.com.
Whichever form your contribution takes, it makes this project better!
### Sponsors
[上海逆行信息科技](http://www.desmix.com/)
### Contributors
<p align="center">
<img src="https://opencollective.com/chinese-poetry/contributors.svg?width=890&button=false" alt="Contributors">
</p>
## Showcase
<details>
<summary>Showcase</summary>
- [Chinese poetry homepage](https://shici.store), a browser-based poetry website that includes the Three Hundred Tang Poems, the Three Hundred Song Ci and other collections.
- [animalize](https://github.com/animalize) **/** [QuanTangshi](https://github.com/animalize/QuanTangshi) *offline complete Tang poetry for Android*
- [justdark](https://github.com/justdark) **/** [pytorch-poetry-gen](https://github.com/justdark/pytorch-poetry-gen) *a char-RNN based on pytorch*
- [Clover27](https://github.com/Clover27) **/** [ancient-Chinese-poem-generator](https://github.com/Clover27/ancient-Chinese-poem-generator) *Ancient-Chinese-Poem-Generator*
- [chinese-poetry](https://github.com/chinese-poetry) **/** [poetry-calendar](http://shici.store/poetry-calendar/) *poetry weekly calendar*
- [chenyuntc](https://github.com/chenyuntc) **/** [pytorch-book](https://github.com/chenyuntc/pytorch-book/blob/master/chapter9-神经网络写诗(CharRNN)/) *simplified-Chinese Tang poetry generation (char-RNN), supporting acrostic poems, custom moods, prefixes, etc.*
- [okcy1016](https://github.com/okcy1016) **/** [poetry-desktop](https://github.com/okcy1016/poetry-desktop/) *poetry desktop app*
- [huangjianke](https://github.com/huangjianke) **/** [weapp-poem](https://github.com/huangjianke/weapp-poem/) *诗词墨客, a WeChat mini-program*
- [汉字之美](https://hz.xusenlin.com/), a clean and simple poetry website that is easy to search and use.
</details>
## License
Released under the [MIT](https://github.com/chinese-poetry/chinese-poetry/blob/master/LICENSE) license.
Hello, what is your name?
Hello, my name is Little Cute. Nice to meet you!
Nice to meet you too. Could you please introduce yourself in a few words? Thank you!
Great. My father is Zzq. He is from China and is a really kind person.
That sounds good!
Thank you.
Emmm, I want to learn more about him, mainly about his history.
He majored in both bioscience and computer science. Last year, he got an offer from a top university in his country!
His dream is to combine computer science and neuroscience to realize real AI. Emmm, maybe that will be one of my brothers, who knows.
What an amazing dream! Congratulations!
Thank you, have a nice day!
You too!
# -*- coding: utf-8 -*-
# code by Steve Zhang Z
# Time: 4/22/2021
# email: stevezhangz@163.com
import re
from random import *
import torch.utils.data as Data
import os
import json
import thulac
import numpy as np
class general_transform_text2list:
"""
notification: All series of data process method here only support the list type sentences, so whether json or txt file
should be transformed into list type, such as [s1,s2,s3,s4,s5]
"""
def __init__(self,text_dir,args=[],type="json"):
assert type in ["json","txt"], print("plz give a txt or a json")
self.text=text_dir
self.type=type
self.arg=args
def getdata(self):
if self.type=="json":
return self.for_json()
elif self.type=="txt":
return self.for_txt()
else:
raise KeyError
def for_txt(self):
sentences=[]
if os.path.exists(self.text):
with open(self.text,"r") as f:
for i in f:
sentences.append(i)
return sentences
def for_json(self):
sentences=[]
keys=[i for i in self.arg]
with open(self.text,"r") as f:
f=json.load(f)
for i in f:
for key in keys:
if isinstance(i[key],list):
for j in i[key]:
sentences.append(j)
elif isinstance(i[key],str):
sentences.append(i[key])
return sentences
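# Usage sketch (illustrative): the same two calls the training script makes, reading
# either the plain-text demo file or the chuci json and keeping its "content" field.
def _demo_text2list():
    txt_lines = general_transform_text2list("data/demo.txt", type="txt").getdata()
    poem_lines = general_transform_text2list("data/chinese-poetry/chuci/chuci.json",
                                             type="json", args=["content"]).getdata()
    return txt_lines, poem_lines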
class generate_vocab_normalway:
"""
    :notification: before using this class, transform the texts into the form
        [s1, s2, s3, ...], where each "s" is an individual sentence.
    :param text_list: the transformed text (a list of sentences).
    :param map_dir: path of a json file storing the bijection between words and their
        ids. It is updated whenever a new task introduces new words, and it contains
        three keys: "words" (the vocabulary), "word2idx" (word -> id) and
        "idx2word" (id -> word).
"""
def __init__(self,text_list,map_dir,record_update=True,language="Chinese"):
self.text_list=text_list
self.map_dir=map_dir
self.language=language
self.update=record_update
def transform(self):
if os.path.exists(self.map_dir):
with open(self.map_dir, "r") as file_:
map_file = json.load(file_)
use_before = 1
file_.close()
else:
use_before = 0
cut = thulac.thulac()
if use_before:
words = map_file["words"]
else:
words = []
sentences = []
for i in self.text_list:
sentence = re.sub("[.,!?,。::\n\\-]", '', i.lower())
if self.language=="Chinese":
sentence = list(set(np.array(cut.cut(sentence, text=False))[:, 0]))
elif self.language=="English":
sentence=list(set(sentence.split(" ")))
sentences.append(sentence)
for i in sentence:
if use_before:
if i not in map_file["words"]:
words.append(i)
else:
words.append(i)
words = list(set(words))
if use_before:
word2idx = map_file["word2idx"]
else:
word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for seq, val in enumerate(words):
word2idx[val] = seq + 4
id_sentence = []
for i in sentences:
id_sentence.append([word2idx[j] for j in i])
vocab_size = len(word2idx)
idx2word = {i: w for i, w in enumerate(word2idx)}
if self.update:
if use_before:
map_file["word2idx"] = word2idx
map_file["idx2word"] = idx2word
map_file["words"] = words
with open(self.map_dir, "w") as file_:
json.dump(map_file, file_)
file_.close()
            else:
                # no previous vocabulary file: build a fresh one
                map_dir = self.map_dir if self.map_dir is not None else "word_info.json"
map_file = {}
map_file["word2idx"] = word2idx
map_file["idx2word"] = idx2word
map_file["words"] = words
with open(map_dir, "w") as f:
json.dump(map_file, f)
return sentences, id_sentence, idx2word, word2idx, vocab_size
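# Vocabulary sketch (illustrative): transform() re-tokenises the sentences (thulac for
# Chinese, whitespace for English), reserves ids 0-3 for [PAD]/[CLS]/[SEP]/[MASK] and
# persists the word<->id maps to map_dir so later runs can reuse and extend them.
def _demo_build_vocab(text_list, map_dir="words_info.json"):
    sentences, id_sentence, idx2word, word2idx, vocab_size = \
        generate_vocab_normalway(text_list, map_dir=map_dir).transform()
    return word2idx["[MASK]"], vocab_size  # 3, and the total number of distinct tokens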
def generate_vocab_from_poem_chuci(poem_dir,map_dir):
"""
    :poem introduction: the Chu Ci was written by Qu Yuan, a great poet of ancient China.
    :data link (thanks): https://codechina.csdn.net/mirrors/chinese-poetry/chinese-poetry?utm_source=csdn_github_accelerator
    :param poem_dir: data/chinese-poetry/chuci/chuci.json
"""
if os.path.exists(map_dir):
with open(map_dir,"r") as file_:
map_file=json.load(file_)
use_before=1
file_.close()
else:
use_before=0
cut=thulac.thulac()
if not os.path.exists(poem_dir):
raise FileNotFoundError
else:
with open(poem_dir, "r") as f:
json_file = json.load(f)
if use_before:
words=map_file["words"]
else:
words=[]
sentences=[]
for poem in json_file:
for i in poem["content"]:
sentence= re.sub("[.,!?,。::\\-]", '', i.lower())
sentence=list(set(np.array(cut.cut(sentence,text=False))[:,0]))
sentences.append(sentence)
for i in sentence:
if use_before:
if i not in map_file["words"]:
words.append(i)
else:
words.append(i)
words=list(set(words))
if use_before:
word2idx =map_file["word2idx"]
else:
word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for seq,val in enumerate(words):
word2idx[val]=seq+4
id_sentence=[]
for i in sentences:
id_sentence.append([word2idx[j] for j in i])
vocab_size=len(word2idx)
idx2word={i:w for i,w in enumerate(word2idx)}
if use_before:
map_file["word2idx"]=word2idx
map_file["idx2word"]=idx2word
map_file["words"]=words
with open(map_dir, "w") as file_:
json.dump(map_file,file_)
file_.close()
else:
        if map_dir is None:
map_dir="word_info.json"
map_file={}
map_file["word2idx"] = word2idx
map_file["idx2word"] = idx2word
map_file["words"] = words
with open(map_dir,"w") as f:
json.dump(map_file,f)
return sentences,id_sentence,idx2word,word2idx,vocab_size
def creat_batch(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,sentences):
batch = []
positive = negative = 0
while positive != batch_size / 2 or negative != batch_size / 2:
tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(
len(sentences))
tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
cand_maked_pos = [i for i, token in enumerate(input_ids)
if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
shuffle(cand_maked_pos)
masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            p = random()
            if p < 0.8:  # 80% of the selected positions become [MASK]
                input_ids[pos] = word2idx['[MASK]']
            elif p > 0.9:  # 10% become a random non-special token; the remaining 10% stay unchanged
                input_ids[pos] = randint(4, vocab_size - 1)
n_pad = maxlen - len(input_ids)
input_ids.extend([0] * n_pad)
segment_ids.extend([0] * n_pad)
if max_pred > n_pred:
n_pad = max_pred - n_pred
masked_tokens.extend([0] * n_pad)
masked_pos.extend([0] * n_pad)
if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
positive += 1
elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
negative += 1
return batch
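# Batch sketch (illustrative): every element produced above is
# [input_ids, segment_ids, masked_tokens, masked_pos, isNext], padded to maxlen and
# max_pred, with exactly half positive next-sentence pairs and half random pairs.
def _demo_batch(word2idx, token_list, sentences, vocab_size):
    return creat_batch(batch_size=6, max_pred=5, maxlen=150, vocab_size=vocab_size,
                       word2idx=word2idx, token_list=token_list, sentences=sentences)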
class Text_file(Data.Dataset):
def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
self.input_ids = input_ids
self.segment_ids = segment_ids
self.masked_tokens = masked_tokens
self.masked_pos = masked_pos
self.isNext = isNext
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[
idx]
absl-py==0.11.0
altgraph==0.17
astor==0.8.1
astunparse==1.6.3
Babel==2.9.0
bce-python-sdk==0.8.59
cachetools==4.2.1
certifi==2020.12.5
cfgv==3.2.0
chardet==4.0.0
configparser==5.0.2
cycler==0.10.0
decorator==5.0.5
flake8==3.9.0
Flask==1.1.2
Flask-Babel==2.0.0
flatbuffers==1.12
gast==0.3.3
gedit==0.0.2
genpac==2.1.0
google-auth==1.24.0
google-auth-oauthlib==0.4.2
google-pasta==0.2.0
graphviz==0.16
grpcio==1.32.0
h5py==2.10.0
identify==2.2.3
idna==2.10
install==1.3.4
itsdangerous==1.1.0
jupyterthemes==0.20.0
kaggle==1.5.12
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
lesscpy==0.14.0
Markdown==3.3.3
matplotlib==3.3.4
mccabe==0.6.1
nodeenv==1.6.0
numpy==1.19.5
oauthlib==3.1.0
ofa==0.1.0.post202012082159
opt-einsum==3.3.0
paddlepaddle-gpu==2.0.1.post110
Pillow==8.1.0
ply==3.11
pre-commit==2.12.0
protobuf==3.14.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.7.0
pycryptodome==3.10.1
pydot==1.4.1
pyflakes==2.3.1
Pygments==2.8.1
pyinstaller==4.2
pyinstaller-hooks-contrib==2021.1
pyparsing==2.4.7
python-dateutil==2.8.1
python-slugify==4.0.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7
shellcheck-py==0.7.1.1
six==1.15.0
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorboardX==2.1
tensorflow-estimator==2.4.0
tensorflow-gpu==2.4.1
termcolor==1.1.0
text-unidecode==1.3
thulac==0.2.1
toml==0.10.2
torch==1.8.0+cu111
torchaudio==0.8.0
torchsummary==1.5.1
torchvision==0.9.0+cu111
touch==2020.12.3
tqdm==4.59.0
typing-extensions==3.7.4.3
urllib3==1.26.3
values==2020.12.3
visualdl==2.1.1
Werkzeug==1.0.1
wrapt==1.12.1
# -*- coding: utf-8 -*-
# code by Steve Zhang Z
# Time: 4/22/2021
# email: stevezhangz@163.com
from bert import *
import torch
import torch.nn as nn
import torch.optim as optim
from Config_load import *
from data_process import *
np.random.seed(random_seed)
seed(random_seed)            # python's random module drives the masking in creat_batch
torch.manual_seed(random_seed)
# transform json to list
#json2list=general_transform_text2list("data/demo.txt",type="txt")
json2list=general_transform_text2list("data/chinese-poetry/chuci/chuci.json",type="json",args=['content'])
data=json2list.getdata()
# transform list to token
list2token=generate_vocab_normalway(data,map_dir="words_info.json")
sentences,token_list,idx2word,word2idx,vocab_size=list2token.transform()
batch = creat_batch(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,sentences)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
torch.LongTensor(masked_pos), torch.LongTensor(isNext)
loader = Data.DataLoader(Text_file(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)
model=Bert(n_layers=n_layers,
vocab_size=vocab_size,
emb_size=d_model,
max_len=maxlen,
seg_size=n_segments,
dff=d_ff,
dk=d_k,
dv=d_v,
n_head=n_heads,
n_class=2,
)
# move the model to the requested device (fall back to CPU when use_gpu is 0)
device = torch.device(device) if use_gpu else torch.device("cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=lr)
model.Train(epoches=epoches,
            train_data_loader=loader,
            optimizer=optimizer,
            criterion=criterion,
            save_dir=weight_dir,
            save_freq=100,
            load_dir="checkpoint/checkpoint_199.pth",
            use_gpu=use_gpu,
            device=device
            )
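# Inference sketch (illustrative): after training, ask the model for its guesses at the
# masked positions of one batch and map the predicted ids back to words.
with torch.no_grad():
    sample_ids, sample_seg, sample_tokens, sample_pos, _ = next(iter(loader))
    sample_ids, sample_seg, sample_pos = (sample_ids.to(device),
                                          sample_seg.to(device),
                                          sample_pos.to(device))
    pred_lm, _ = model(x=sample_ids, seg=sample_seg, mask_=sample_pos)
    predicted_ids = pred_lm.argmax(dim=-1)[0].tolist()
    print("predicted:", [idx2word[i] for i in predicted_ids])
    print("original :", [idx2word[i] for i in sample_tokens[0].tolist()])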