Commit c90da403 authored by weixin_44417093

Sun Dec 3 15:23:00 CST 2023 inscode

Parent 37d51ffd
run = "pip install -r requirements.txt;python main.py"
language = "python"
[packager]
AUTO_PIP = true
@@ -9,4 +10,7 @@ PATH = "${VIRTUAL_ENV}/bin:${PATH}"
 PYTHONPATH = "$PYTHONHOME/lib/python3.10:${VIRTUAL_ENV}/lib/python3.10/site-packages"
 REPLIT_POETRY_PYPI_REPOSITORY = "http://mirrors.csdn.net.cn/repository/csdn-pypi-mirrors/simple"
 MPLBACKEND = "TkAgg"
-POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry"
\ No newline at end of file
+POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry"
+
+[debugger]
+program = "main.py"
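For context, the variables above configure the hosted environment (package mirror, matplotlib backend, Poetry cache). A minimal sketch, not part of the commit, that checks whether these variables from the .replit file are actually visible to the Python process at runtime:

import os

# Variable names taken from the .replit diff above; anything not exported
# simply prints as None.
for name in ("PYTHONPATH", "REPLIT_POETRY_PYPI_REPOSITORY",
             "MPLBACKEND", "POETRY_CACHE_DIR"):
    print(f"{name} = {os.environ.get(name)}")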
@@ -21,11 +21,8 @@ class MyDataset(Dataset):
     def __getitem__(self, idx):
         tokens = self.data[idx].split()
         tokens = [self.tokenizer.get(token, self.tokenizer['<UNK>']) for token in tokens]
-        if len(tokens) < self.max_len:
-            tokens += [self.tokenizer['<PAD>']] * (self.max_len - len(tokens))
-        else:
-            tokens = tokens[:self.max_len]
-        return torch.tensor(tokens)
+        tokens += [self.tokenizer['<PAD>']] * (self.max_len - len(tokens))
+        return torch.tensor(tokens[:self.max_len])
 
 dataset = MyDataset(data, tokenizer, max_len)
 dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
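The hunk above replaces the pad-or-truncate branch with an unconditional pad followed by a slice; a negative pad count just yields an empty list, so long inputs are still truncated. A small self-contained check of the equivalence (the toy vocabulary and max_len below are made up for illustration; only '<PAD>' and '<UNK>' mirror MyDataset's keys):

import torch

tokenizer = {'<PAD>': 0, '<UNK>': 1, 'hello': 2, 'pytorch': 3}
max_len = 4

def old_getitem(tokens):
    # Previous branchy version: pad short sequences, truncate long ones.
    if len(tokens) < max_len:
        tokens += [tokenizer['<PAD>']] * (max_len - len(tokens))
    else:
        tokens = tokens[:max_len]
    return torch.tensor(tokens)

def new_getitem(tokens):
    # Committed version: always pad, then slice to max_len.
    tokens = tokens + [tokenizer['<PAD>']] * (max_len - len(tokens))
    return torch.tensor(tokens[:max_len])

for seq in (['hello'], ['hello', 'pytorch', 'hello', 'pytorch', 'hello']):
    ids = [tokenizer.get(t, tokenizer['<UNK>']) for t in seq]
    assert torch.equal(old_getitem(list(ids)), new_getitem(list(ids)))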
@@ -62,7 +59,7 @@ for epoch in range(10):
 # Test the model
 test_data = ['hello', 'goodbye', 'pytorch']
 test_tokens = [[tokenizer.get(token, tokenizer['<UNK>']) for token in data.split()] for data in test_data]
-test_tokens = [torch.tensor(tokens) for tokens in test_tokens]
+test_tokens = [torch.tensor(tokens[:max_len]) for tokens in test_tokens]
 test_output = model(torch.stack(test_tokens))
 test_pred = torch.argmax(test_output, dim=-1)
 for i, data in enumerate(test_data):
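The truncation added here keeps each test sequence within max_len; torch.stack still works because every test string is a single word, so all tensors have length 1. A hedged sketch of one way to batch multi-word test strings with the same pad-then-truncate convention as MyDataset (encode_batch is a hypothetical helper, not part of the commit):

import torch

def encode_batch(texts, tokenizer, max_len):
    # Pad then truncate each string so torch.stack sees equal-length tensors.
    batch = []
    for text in texts:
        ids = [tokenizer.get(tok, tokenizer['<UNK>']) for tok in text.split()]
        ids = ids + [tokenizer['<PAD>']] * (max_len - len(ids))
        batch.append(torch.tensor(ids[:max_len]))
    return torch.stack(batch)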
@@ -71,7 +68,7 @@ for i, data in enumerate(test_data):
 # Apply the model
 input_data = 'hello'
 input_tokens = [tokenizer.get(token, tokenizer['<UNK>']) for token in input_data.split()]
-input_tokens = torch.tensor(input_tokens).unsqueeze(0)
+input_tokens = torch.tensor(input_tokens[:max_len]).unsqueeze(0)
 output = model(input_tokens)
 pred = torch.argmax(output, dim=-1)
-print(' '.join([k for k, v in tokenizer.items() if v == pred.item()]))
\ No newline at end of file
+print(' '.join([k for k, v in tokenizer.items() if v == pred.item()]))
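The final print reverse-maps the predicted id by scanning the whole vocabulary on every call. An optional alternative, not in the commit, reusing the tokenizer and pred variables from the code above and assuming token ids are unique:

# Precompute the inverse vocabulary once instead of scanning tokenizer.items().
id_to_token = {v: k for k, v in tokenizer.items()}
print(id_to_token.get(pred.item(), '<UNK>'))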