From c90da4038a32eccc247ef77b2d2758fcd76a0089 Mon Sep 17 00:00:00 2001
From: weixin_44417093
Date: Sun, 3 Dec 2023 15:23:00 +0800
Subject: [PATCH] Sun Dec 3 15:23:00 CST 2023 inscode

---
 .inscode |  6 +++++-
 main.py  | 13 +++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/.inscode b/.inscode
index 7d17616..3a001eb 100644
--- a/.inscode
+++ b/.inscode
@@ -1,4 +1,5 @@
 run = "pip install -r requirements.txt;python main.py"
+language = "python"
 
 [packager]
 AUTO_PIP = true
@@ -9,4 +10,7 @@ PATH = "${VIRTUAL_ENV}/bin:${PATH}"
 PYTHONPATH = "$PYTHONHOME/lib/python3.10:${VIRTUAL_ENV}/lib/python3.10/site-packages"
 REPLIT_POETRY_PYPI_REPOSITORY = "http://mirrors.csdn.net.cn/repository/csdn-pypi-mirrors/simple"
 MPLBACKEND = "TkAgg"
-POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry"
\ No newline at end of file
+POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry"
+
+[debugger]
+program = "main.py"
diff --git a/main.py b/main.py
index eec8511..f0a0494 100644
--- a/main.py
+++ b/main.py
@@ -21,11 +21,8 @@ class MyDataset(Dataset):
     def __getitem__(self, idx):
         tokens = self.data[idx].split()
         tokens = [self.tokenizer.get(token, self.tokenizer['']) for token in tokens]
-        if len(tokens) < self.max_len:
-            tokens += [self.tokenizer['']] * (self.max_len - len(tokens))
-        else:
-            tokens = tokens[:self.max_len]
-        return torch.tensor(tokens)
+        tokens += [self.tokenizer['']] * (self.max_len - len(tokens))
+        return torch.tensor(tokens[:self.max_len])
 
 dataset = MyDataset(data, tokenizer, max_len)
 dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
@@ -62,7 +59,7 @@ for epoch in range(10):
 # 测试模型
 test_data = ['hello', 'goodbye', 'pytorch']
 test_tokens = [[tokenizer.get(token, tokenizer['']) for token in data.split()] for data in test_data]
-test_tokens = [torch.tensor(tokens) for tokens in test_tokens]
+test_tokens = [torch.tensor(tokens[:max_len]) for tokens in test_tokens]
 test_output = model(torch.stack(test_tokens))
 test_pred = torch.argmax(test_output, dim=-1)
 for i, data in enumerate(test_data):
@@ -71,7 +68,7 @@ for i, data in enumerate(test_data):
 # 应用模型
 input_data = 'hello'
 input_tokens = [tokenizer.get(token, tokenizer['']) for token in input_data.split()]
-input_tokens = torch.tensor(input_tokens).unsqueeze(0)
+input_tokens = torch.tensor(input_tokens[:max_len]).unsqueeze(0)
 output = model(input_tokens)
 pred = torch.argmax(output, dim=-1)
-print(' '.join([k for k, v in tokenizer.items() if v == pred.item()]))
\ No newline at end of file
+print(' '.join([k for k, v in tokenizer.items() if v == pred.item()]))
-- 
GitLab