提交 37d51ffd 编写于 作者: 飞 羽's avatar 飞 羽

生产代码片段

上级 2eebcf6c
print('欢迎来到 InsCode')
\ No newline at end of file
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Prepare the toy dataset: four 2-word sentences over a 4-word vocabulary,
# plus a word->id map with dedicated padding / unknown-word ids.
data = ['hello world', 'goodbye world', 'hello pytorch', 'goodbye pytorch']
tokenizer = {'<PAD>': 0, '<UNK>': 1, 'hello': 2, 'world': 3, 'goodbye': 4, 'pytorch': 5}
max_len = 3
# 数据预处理
# Data preprocessing: tokenize each sentence and pad/truncate to a fixed length.
class MyDataset(Dataset):
    """Map-style dataset turning whitespace-separated sentences into
    fixed-length LongTensors of token ids.

    Unknown words map to the '<UNK>' id; short sentences are right-padded
    with the '<PAD>' id and long ones are truncated to ``max_len``.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data            # list[str]: raw sentences
        self.tokenizer = tokenizer  # dict[str, int]: word -> id, must contain '<PAD>' and '<UNK>'
        self.max_len = max_len      # int: fixed output sequence length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx].split()
        # Words missing from the vocabulary fall back to the '<UNK>' id.
        tokens = [self.tokenizer.get(token, self.tokenizer['<UNK>']) for token in tokens]
        if len(tokens) < self.max_len:
            # Right-pad short sequences with '<PAD>'.
            tokens += [self.tokenizer['<PAD>']] * (self.max_len - len(tokens))
        else:
            # Truncate sequences longer than max_len.
            tokens = tokens[:self.max_len]
        return torch.tensor(tokens)
# Wrap the corpus in the Dataset and batch it: 2 sentences per batch,
# reshuffled every epoch.
dataset = MyDataset(data, tokenizer, max_len)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Build the model: Embedding -> LSTM -> Linear projection over the vocabulary.
class MyModel(nn.Module):
    """Predict a vocabulary logit distribution for every sequence position."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(MyModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: inputs/outputs are (batch, seq, feature).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """x: LongTensor (batch, seq) of token ids -> logits (batch, seq, vocab_size)."""
        x = self.embedding(x)
        out, _ = self.lstm(x)  # final hidden/cell states are discarded
        out = self.fc(out)
        return out
model = MyModel(len(tokenizer), 10, 20)

# Train the model.
# NOTE(review): the loss target is the input batch itself, so the model learns
# to reproduce each token at its own position (not a shifted next-token
# objective) — confirm this is intended.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
for epoch in range(10):
    for batch in dataloader:
        optimizer.zero_grad()
        output = model(batch)  # (batch, seq, vocab)
        # Flatten to (batch*seq, vocab) logits vs (batch*seq,) targets for CrossEntropyLoss.
        loss = criterion(output.view(-1, len(tokenizer)), batch.view(-1))
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch of the epoch.
    print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))
# Test the model: feed single-word inputs and print the predicted word for each.
test_data = ['hello', 'goodbye', 'pytorch']
# Loop variable renamed from `data` to avoid shadowing the training corpus.
test_tokens = [[tokenizer.get(token, tokenizer['<UNK>']) for token in sentence.split()]
               for sentence in test_data]
test_tokens = [torch.tensor(tokens) for tokens in test_tokens]
# torch.stack works here because every test sentence has the same length (1 token).
test_output = model(torch.stack(test_tokens))  # (3, 1, vocab)
test_pred = torch.argmax(test_output, dim=-1)  # (3, 1)
for i, sentence in enumerate(test_data):
    # Reverse-lookup the predicted id in the tokenizer (vocabulary is tiny).
    predicted = ' '.join([k for k, v in tokenizer.items() if v == test_pred[i].item()])
    print('{} -> {}'.format(sentence, predicted))
# Apply the model to a single input string.
input_data = 'hello'
input_tokens = [tokenizer.get(token, tokenizer['<UNK>']) for token in input_data.split()]
input_tokens = torch.tensor(input_tokens).unsqueeze(0)  # add batch dim -> (1, seq)
output = model(input_tokens)           # (1, seq, vocab)
pred = torch.argmax(output, dim=-1)    # (1, seq)
# The original pred.item() raises for inputs longer than one token; iterating
# the predicted sequence handles multi-word inputs and prints the identical
# result for the existing single-token case.
print(' '.join(k for t in pred[0] for k, v in tokenizer.items() if v == t.item()))
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册