提交 6d821235 编写于 作者: Z zhangxinfeng3

update some comments of api

上级 b7c92aa8
此差异已折叠。
......@@ -25,6 +25,9 @@ class ClassWrap:
def __init__(self, cls):
self._cls = cls
self.bnn_loss_file = None
self.__doc__ = cls.__doc__
self.__name__ = cls.__name__
self.__bases__ = cls.__bases__
def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
......
......@@ -31,7 +31,7 @@ class ConditionalVAE(Cell):
Note:
When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
should be :math:`(N, hidden_size)`.
should be :math:`(N, hidden\_size)`.
The latent_size should be less than or equal to the hidden_size.
Args:
......@@ -42,8 +42,8 @@ class ConditionalVAE(Cell):
num_classes(int): The number of classes.
Inputs:
- **input_x** (Tensor) - the same shape as the input of encoder.
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N, 1)`.
- **input_x** (Tensor) - the same shape as the input of encoder, the shape is :math:`(N, C, H, W)`.
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N,)`.
Outputs:
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
......@@ -99,7 +99,7 @@ class ConditionalVAE(Cell):
Randomly sample from latent space to generate sample.
Args:
sample_y (Tensor): Define the label of sample, int tensor.
sample_y (Tensor): Define the label of sample, int tensor, the shape is (generate_nums, ).
generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be (generate_nums, C, H, W) or (-1, C, H, W).
......@@ -121,8 +121,8 @@ class ConditionalVAE(Cell):
Reconstruct sample from original data.
Args:
x (Tensor): The input tensor to be reconstructed.
y (Tensor): The label of the input tensor.
x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
y (Tensor): The label of the input tensor, the shape is (N,).
Returns:
Tensor, the reconstructed sample.
......
......@@ -29,7 +29,7 @@ class VAE(Cell):
Note:
When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
should be :math:`(N, hidden_size)`.
should be :math:`(N, hidden\_size)`.
The latent_size should be less than or equal to the hidden_size.
Args:
......@@ -39,7 +39,7 @@ class VAE(Cell):
latent_size(int): The size of the latent space.
Inputs:
- **input** (Tensor) - the same shape as the input of encoder.
- **input** (Tensor) - the same shape as the input of encoder, the shape is :math:`(N, C, H, W)`.
Outputs:
- **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
......@@ -106,7 +106,7 @@ class VAE(Cell):
Reconstruct sample from original data.
Args:
x (Tensor): The input tensor to be reconstructed.
x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
Returns:
Tensor, the reconstructed sample.
......
......@@ -37,7 +37,7 @@ class ELBO(Cell):
Inputs:
- **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
- **target_data** (Tensor) - the target tensor.
- **target_data** (Tensor) - the target tensor of shape :math:`(N,)`.
Outputs:
Tensor, loss float tensor.
......
......@@ -98,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
return mnist_ds
def test_svi_cave():
def test_svi_cvae():
# define the encoder and decoder
encoder = Encoder(num_classes=10)
decoder = Decoder()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册