# wdtypes.py (committed by jrzaurin)
import sys
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
from types import SimpleNamespace
from typing import (
    Any,
    Dict,
    List,
    Match,
    Tuple,
    Union,
    Callable,
    Iterable,
    Iterator,
    Optional,
    Generator,
    Collection,
)
from pathlib import PosixPath
18

19
import torch
20
from torch import Tensor
21 22
from torch.nn import Module
from torch.optim.optimizer import Optimizer
J
jrzaurin 已提交
23
from torchvision.transforms import (
24 25 26 27
    Pad,
    Scale,
    Lambda,
    Resize,
J
jrzaurin 已提交
28
    Compose,
29
    TenCrop,
J
jrzaurin 已提交
30
    FiveCrop,
31
    ToTensor,
J
jrzaurin 已提交
32 33
    Grayscale,
    Normalize,
34
    CenterCrop,
J
jrzaurin 已提交
35
    RandomCrop,
36 37 38
    ToPILImage,
    ColorJitter,
    RandomApply,
J
jrzaurin 已提交
39
    RandomOrder,
40 41
    RandomAffine,
    RandomChoice,
J
jrzaurin 已提交
42
    RandomRotation,
43
    RandomGrayscale,
J
jrzaurin 已提交
44
    RandomSizedCrop,
45
    RandomResizedCrop,
J
jrzaurin 已提交
46
    RandomVerticalFlip,
47 48
    LinearTransformation,
    RandomHorizontalFlip,
J
jrzaurin 已提交
49
)
50
from torch.optim.lr_scheduler import _LRScheduler
51
from torch.utils.data.dataloader import DataLoader
52

53
from pytorch_widedeep.models import WideDeep
54
from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax
55
from pytorch_widedeep.models.transformers.layers import FullEmbeddingDropout
56

J
jrzaurin 已提交
57
# Nested collection of string tokens; presumably one inner collection per
# tokenised text (confirm against the tokenizer that produces it).
Tokens = Collection[Collection[str]]

# Collection of text-processing rules. Each rule is a ``str -> str``
# callable; presumably the rules are applied one after another to raw text.
ListRules = Collection[Callable[[str], str]]
# Union of the torchvision transform classes that a ``Transforms``
# annotation accepts. Members are kept in alphabetical order.
# NOTE(review): ``Scale`` and ``RandomSizedCrop`` are deprecated torchvision
# names (superseded by ``Resize`` / ``RandomResizedCrop``). Newer torchvision
# releases no longer ship them, so confirm the pinned version still provides
# them.
Transforms = Union[
    (
        CenterCrop, ColorJitter, Compose, FiveCrop, Grayscale,
        Lambda, LinearTransformation, Normalize, Pad,
        RandomAffine, RandomApply, RandomChoice, RandomCrop,
        RandomGrayscale, RandomHorizontalFlip, RandomOrder,
        RandomResizedCrop, RandomRotation, RandomSizedCrop,
        RandomVerticalFlip, Resize, Scale, TenCrop,
        ToPILImage, ToTensor,
    )
]
# Learning-rate scheduler base class, taken from
# ``torch.optim.lr_scheduler._LRScheduler`` (a private name).
# NOTE(review): newer torch releases also expose a public ``LRScheduler``
# in the same module. Consider switching once the minimum supported torch
# version allows it.
LRScheduler = _LRScheduler

# Iterable of model parameters. Judging by the name, this is intended for
# values such as ``model.parameters()``.
# ``torch.nn.Module.parameters()`` is annotated as ``Iterator[Parameter]``
# (``Parameter`` subclasses ``Tensor``), so this alias must accept a plain
# iterator.
# ``Generator[Tensor, Tensor, Tensor]`` would be wrong here for two reasons:
#   - it additionally claims the object accepts ``send(Tensor)`` and returns
#     a ``Tensor`` when exhausted, and ``parameters()`` does neither;
#   - type checkers reject ``Iterator[Parameter]`` against it.
# ``Iterator[Tensor]`` accepts both iterators and generators of tensors.
ModelParams = Iterator[Tensor]

# Normalisation layers: either ``LayerNorm`` or ``BatchNorm1d``.
NormLayers = Union[torch.nn.LayerNorm, torch.nn.BatchNorm1d]
# Dropout layers accepted by ``DropoutLayers`` annotations: torch's standard
# ``Dropout`` or the project's ``FullEmbeddingDropout`` layer.
DropoutLayers = Union[
    torch.nn.Dropout,
    FullEmbeddingDropout,
]