Unverified commit 58084df3, authored by qiuqiu, committed by GitHub

feature(gg15): add gumbel softmax (#169)

* feature(gg15): add gumbel softmax

* add ww

* fix format
Parent 118cc673
@@ -8,3 +8,4 @@ from .soft_argmax import SoftArgmax
from .transformer import Transformer
from .scatter_connection import ScatterConnection
from .resnet import resnet18, ResNet
from .gumbel_softmax import GumbelSoftmax
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelSoftmax(nn.Module):
    r"""
    Overview:
        An nn.Module that samples from the Gumbel-Softmax distribution.
    Interface:
        __init__, forward
    .. note::
        For more details on Gumbel-Softmax, refer to the paper
        <https://arxiv.org/abs/1611.01144>.
    """

    def __init__(self) -> None:
        r"""
        Overview:
            Initialize the GumbelSoftmax module.
        """
        super(GumbelSoftmax, self).__init__()

    def gumbel_softmax_sample(self, x: torch.Tensor, temperature: float, eps: float = 1e-8) -> torch.Tensor:
        """Draw a sample from the Gumbel-Softmax distribution."""
        # Sample Gumbel noise g = -log(-log(U)) with U ~ Uniform(0, 1);
        # eps keeps the logarithms numerically stable.
        U = torch.rand(x.shape)
        U = U.to(x.device)
        y = x - torch.log(-torch.log(U + eps) + eps)
        # Softmax of the noise-perturbed logits, scaled by the temperature.
        return F.softmax(y / temperature, dim=1)

    def forward(self, x: torch.Tensor, temperature: float = 1.0, hard: bool = False) -> torch.Tensor:
        r"""
        Arguments:
            - x (:obj:`torch.Tensor`): unnormalized log-probs
            - temperature (:obj:`float`): non-negative scalar controlling the smoothness of the samples
            - hard (:obj:`bool`): if True, return a one-hot label
        Returns:
            - output (:obj:`torch.Tensor`): sample from the Gumbel-Softmax distribution
        Shapes:
            - x: :math:`(B, N)`, where B is the batch size and N is the number of classes
            - output: :math:`(B, N)`, where B is the batch size and N is the number of classes
        """
        y = self.gumbel_softmax_sample(x, temperature)
        if hard:
            # Convert the soft sample to a one-hot vector by placing 1 at the argmax position.
            y_hard = torch.zeros_like(x)
            y_hard[torch.arange(0, x.shape[0]), y.max(1)[1]] = 1
            # Straight-through estimator: detach treats (y_hard - y) as a constant,
            # so the forward output equals y_hard while the gradient equals that of the soft sample y.
            y = (y_hard - y).detach() + y
        return y
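
For context (not part of this commit): a minimal sketch of how the straight-through trick in forward behaves, assuming the module is importable via the __init__.py change above. With hard=True the forward output is one-hot, while gradients still flow back to the unnormalized log-probs through the soft sample.

# Illustrative sketch only, not part of the original diff.
import torch

from ding.torch_utils.network import GumbelSoftmax

logits = torch.randn(2, 5, requires_grad=True)
model = GumbelSoftmax()
sample = model(logits, temperature=0.5, hard=True)
# Each row of `sample` is (numerically) one-hot, yet backward treats it as the
# underlying soft sample, so the unnormalized log-probs still receive gradients.
loss = (sample * torch.randn_like(sample)).sum()
loss.backward()
assert logits.grad is not None
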
import pytest
import torch

from ding.torch_utils.network import GumbelSoftmax


@pytest.mark.unittest
class TestGumbelSoftmax:

    def test(self):
        B = 4
        N = 10
        model = GumbelSoftmax()
        # data case 1: soft samples
        for _ in range(N):
            data = torch.rand((B, N))
            data = torch.log(data)
            gumbelsoftmax = model(data, hard=False)
            assert gumbelsoftmax.shape == (B, N)
        # data case 2: hard (one-hot) samples
        for _ in range(N):
            data = torch.rand((B, N))
            data = torch.log(data)
            gumbelsoftmax = model(data, hard=True)
            assert gumbelsoftmax.shape == (B, N)
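
As an additional illustration (not part of the original tests), the sketch below shows the effect of the temperature argument: lower temperatures push the soft samples towards one-hot vectors, higher temperatures towards the uniform distribution.

# Illustrative sketch only, not part of the original diff.
import torch

from ding.torch_utils.network import GumbelSoftmax

model = GumbelSoftmax()
logits = torch.log(torch.rand(4, 10))
for temperature in (0.1, 1.0, 10.0):
    soft = model(logits, temperature=temperature, hard=False)
    # The mean of the per-row maxima approaches 1 as the temperature decreases
    # and approaches 1 / N (here 0.1) as it increases.
    print(temperature, soft.max(dim=1).values.mean().item())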