Commit 0bfb2106 authored by Varuna Jayasiri

fix vit pe

Parent 4cf1d74e
@@ -39,7 +39,7 @@ Here's [an experiment](experiment.html) that trains ViT on CIFAR-10.
 This doesn't do very well because it's trained on a small dataset.
 It's a simple experiment that anyone can run and play with ViTs.
-[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/afdd5332188b11edbdf543360515b595)
 """
 import torch
@@ -114,7 +114,7 @@ class LearnedPositionalEmbeddings(Module):
         * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
         """
         # Get the positional embeddings for the given patches
-        pe = self.positional_encodings[x.shape[0]]
+        pe = self.positional_encodings[:x.shape[0]]
         # Add to patch embeddings and return
         return x + pe
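The one-character change above is the entire fix: `self.positional_encodings[x.shape[0]]` indexes a single row of the parameter, so after broadcasting every patch received the same positional embedding, whereas `[:x.shape[0]]` slices one embedding per patch position. Below is a minimal runnable sketch of the corrected module; it uses plain `nn.Module` in place of the repo's `Module` helper and assumes the parameter shape `[max_len, 1, d_model]` used elsewhere in the file:

```python
import torch
import torch.nn as nn


class LearnedPositionalEmbeddings(nn.Module):
    """Adds learned positional embeddings to patch embeddings."""

    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
        # One learnable embedding per possible patch position
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))

    def forward(self, x: torch.Tensor):
        # `x` has shape `[patches, batch_size, d_model]`.
        # Slice the first `patches` rows -> `[patches, 1, d_model]`,
        # which broadcasts over the batch dimension.
        # The buggy version, `self.positional_encodings[x.shape[0]]`,
        # picked a single `[1, d_model]` row instead, so all patches
        # shared one positional embedding.
        pe = self.positional_encodings[:x.shape[0]]
        return x + pe


if __name__ == '__main__':
    pe = LearnedPositionalEmbeddings(d_model=512)
    x = torch.randn(64, 8, 512)  # 64 patches, batch of 8
    assert pe(x).shape == x.shape
```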
@@ -7,7 +7,7 @@ summary: >
 # Train a [Vision Transformer (ViT)](index.html) on CIFAR 10
-[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/afdd5332188b11edbdf543360515b595)
 """
 from labml import experiment
@@ -76,7 +76,7 @@ def main():
         'transformer.d_model': 512,
         # Training epochs and batch size
-        'epochs': 1000,
+        'epochs': 32,
         'train_batch_size': 64,
         # Augment CIFAR 10 images for training
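For context, these options are dictionary overrides passed to a labml configs object. The keys shown in the diff (`transformer.d_model`, `epochs`, `train_batch_size`) are real; the scaffolding below is a hedged sketch of the usual `labml` pattern with a toy stand-in for the experiment's own `Configs` class, not the repo's actual experiment code:

```python
from labml import experiment
from labml.configs import BaseConfigs


class Configs(BaseConfigs):
    """Toy stand-in for the experiment's configs class (hypothetical)."""
    epochs: int = 1000
    train_batch_size: int = 64


def main():
    conf = Configs()
    experiment.create(name='vit_cifar10_demo')  # hypothetical run name
    # Dictionary overrides replace the class defaults; this commit
    # drops `epochs` from 1000 to 32 in exactly this spot
    experiment.configs(conf, {
        'epochs': 32,
        'train_batch_size': 64,
    })
    with experiment.start():
        # Training loop would go here; `conf.epochs` now resolves to 32
        ...


if __name__ == '__main__':
    main()
```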