diff --git a/ding/torch_utils/network/__init__.py b/ding/torch_utils/network/__init__.py
index 9e6d84af87e4464c412ff27c8b7a92b98f04acb4..4bb702b223801c0fb894c6fee00d53f251c18906 100644
--- a/ding/torch_utils/network/__init__.py
+++ b/ding/torch_utils/network/__init__.py
@@ -8,3 +8,4 @@ from .soft_argmax import SoftArgmax
 from .transformer import Transformer
 from .scatter_connection import ScatterConnection
 from .resnet import resnet18, ResNet
+from .gumbel_softmax import GumbelSoftmax
diff --git a/ding/torch_utils/network/gumbel_softmax.py b/ding/torch_utils/network/gumbel_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e6177c1797929e6a906f5c8d01ef315e1e4cfd7
--- /dev/null
+++ b/ding/torch_utils/network/gumbel_softmax.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GumbelSoftmax(nn.Module):
+    r"""
+    Overview:
+        An nn.Module that samples from the Gumbel-Softmax distribution
+    Interface:
+        __init__, forward
+
+    .. note::
+
+        For more information about Gumbel-Softmax, you can refer to the paper
+        `Categorical Reparameterization with Gumbel-Softmax <https://arxiv.org/abs/1611.01144>`_
+
+    """
+
+    def __init__(self) -> None:
+        r"""
+        Overview:
+            Initialize the GumbelSoftmax module
+        """
+        super(GumbelSoftmax, self).__init__()
+
+    def gumbel_softmax_sample(self, x: torch.Tensor, temperature: float, eps: float = 1e-8) -> torch.Tensor:
+        """Draw a sample from the Gumbel-Softmax distribution."""
+        # Gumbel(0, 1) noise via inverse transform sampling: g = -log(-log(U)), U ~ Uniform(0, 1)
+        U = torch.rand(x.shape, device=x.device)
+        y = x - torch.log(-torch.log(U + eps) + eps)
+        return F.softmax(y / temperature, dim=1)
+
+    def forward(self, x: torch.Tensor, temperature: float = 1.0, hard: bool = False) -> torch.Tensor:
+        r"""
+        Arguments:
+            - x (:obj:`torch.Tensor`): unnormalized log-probs
+            - temperature (:obj:`float`): non-negative scalar controlling the sharpness of the softmax
+            - hard (:obj:`bool`): if True, return a one-hot sample (straight-through estimator)
+        Returns:
+            - output (:obj:`torch.Tensor`): sample from the Gumbel-Softmax distribution
+        Shapes:
+            - x: :math:`(B, N)`, where B is the batch size and N is the number of classes
+            - output: :math:`(B, N)`, where B is the batch size and N is the number of classes
+        """
+        y = self.gumbel_softmax_sample(x, temperature)
+        if hard:
+            y_hard = torch.zeros_like(x)
+            y_hard[torch.arange(0, x.shape[0]), y.max(1)[1]] = 1
+            # Straight-through estimator: (y_hard - y) is detached and treated as a constant,
+            # so the forward output is one-hot while the gradient is that of the soft sample y.
+            y = (y_hard - y).detach() + y
+        return y
diff --git a/ding/torch_utils/network/tests/test_gumbel_softmax.py b/ding/torch_utils/network/tests/test_gumbel_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..9168f06a594d73305f6602c5ddf01985d5493377
--- /dev/null
+++ b/ding/torch_utils/network/tests/test_gumbel_softmax.py
@@ -0,0 +1,26 @@
+import numpy as np
+import pytest
+import torch
+
+from ding.torch_utils.network import GumbelSoftmax
+
+
+@pytest.mark.unittest
+class TestGumbelSoftmax:
+
+    def test(self):
+        B = 4
+        N = 10
+        model = GumbelSoftmax()
+        # data case 1: soft samples
+        for _ in range(N):
+            data = torch.rand((B, N))
+            data = torch.log(data)
+            gumbelsoftmax = model(data, hard=False)
+            assert gumbelsoftmax.shape == (B, N)
+        # data case 2: hard (one-hot) samples
+        for _ in range(N):
+            data = torch.rand((B, N))
+            data = torch.log(data)
+            gumbelsoftmax = model(data, hard=True)
+            assert gumbelsoftmax.shape == (B, N)
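As a quick illustration of how the new module would be used (not part of the patch; a minimal sketch assuming the ding package from this branch is importable), the soft path returns a differentiable relaxed sample and the hard path returns a one-hot vector whose gradients flow through the soft sample:

    import torch
    from ding.torch_utils.network import GumbelSoftmax

    model = GumbelSoftmax()
    logits = torch.randn(4, 10)  # (B, N) unnormalized log-probs
    soft_sample = model(logits, temperature=1.0, hard=False)  # rows sum to 1, fully differentiable
    hard_sample = model(logits, temperature=0.5, hard=True)   # one-hot forward pass, straight-through gradients

Lower temperatures push the soft sample closer to one-hot at the cost of higher-variance gradients.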