提交 545f9f40 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5028 Add readme and fix some comments of api

Merge pull request !5028 from zhangxinfeng3/master
此差异已折叠。
...@@ -25,6 +25,9 @@ class ClassWrap: ...@@ -25,6 +25,9 @@ class ClassWrap:
def __init__(self, cls): def __init__(self, cls):
self._cls = cls self._cls = cls
self.bnn_loss_file = None self.bnn_loss_file = None
self.__doc__ = cls.__doc__
self.__name__ = cls.__name__
self.__bases__ = cls.__bases__
def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor): def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor) obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
......
...@@ -31,7 +31,7 @@ class ConditionalVAE(Cell): ...@@ -31,7 +31,7 @@ class ConditionalVAE(Cell):
Note: Note:
When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
should be :math:`(N, hidden_size)`. should be :math:`(N, hidden\_size)`.
The latent_size should be less than or equal to the hidden_size. The latent_size should be less than or equal to the hidden_size.
Args: Args:
...@@ -42,8 +42,8 @@ class ConditionalVAE(Cell): ...@@ -42,8 +42,8 @@ class ConditionalVAE(Cell):
num_classes(int): The number of classes. num_classes(int): The number of classes.
Inputs: Inputs:
- **input_x** (Tensor) - the same shape as the input of encoder. - **input_x** (Tensor) - the same shape as the input of encoder, the shape is :math:`(N, C, H, W)`.
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N, 1)`. - **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N,)`.
Outputs: Outputs:
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
...@@ -99,7 +99,7 @@ class ConditionalVAE(Cell): ...@@ -99,7 +99,7 @@ class ConditionalVAE(Cell):
Randomly sample from latent space to generate sample. Randomly sample from latent space to generate sample.
Args: Args:
sample_y (Tensor): Define the label of sample, int tensor. sample_y (Tensor): Define the label of sample, int tensor, the shape is (generate_nums, ).
generate_nums (int): The number of samples to generate. generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be (generate_nums, C, H, W) or (-1, C, H, W). shape(tuple): The shape of sample, it should be (generate_nums, C, H, W) or (-1, C, H, W).
...@@ -121,8 +121,8 @@ class ConditionalVAE(Cell): ...@@ -121,8 +121,8 @@ class ConditionalVAE(Cell):
Reconstruct sample from original data. Reconstruct sample from original data.
Args: Args:
x (Tensor): The input tensor to be reconstructed. x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
y (Tensor): The label of the input tensor. y (Tensor): The label of the input tensor, the shape is (N,).
Returns: Returns:
Tensor, the reconstructed sample. Tensor, the reconstructed sample.
......
...@@ -29,7 +29,7 @@ class VAE(Cell): ...@@ -29,7 +29,7 @@ class VAE(Cell):
Note: Note:
When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
should be :math:`(N, hidden_size)`. should be :math:`(N, hidden\_size)`.
The latent_size should be less than or equal to the hidden_size. The latent_size should be less than or equal to the hidden_size.
Args: Args:
...@@ -39,7 +39,7 @@ class VAE(Cell): ...@@ -39,7 +39,7 @@ class VAE(Cell):
latent_size(int): The size of the latent space. latent_size(int): The size of the latent space.
Inputs: Inputs:
- **input** (Tensor) - the same shape as the input of encoder. - **input** (Tensor) - the same shape as the input of encoder, the shape is :math:`(N, C, H, W)`.
Outputs: Outputs:
- **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
...@@ -106,7 +106,7 @@ class VAE(Cell): ...@@ -106,7 +106,7 @@ class VAE(Cell):
Reconstruct sample from original data. Reconstruct sample from original data.
Args: Args:
x (Tensor): The input tensor to be reconstructed. x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
Returns: Returns:
Tensor, the reconstructed sample. Tensor, the reconstructed sample.
......
...@@ -37,7 +37,7 @@ class ELBO(Cell): ...@@ -37,7 +37,7 @@ class ELBO(Cell):
Inputs: Inputs:
- **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
- **target_data** (Tensor) - the target tensor. - **target_data** (Tensor) - the target tensor of shape :math:`(N,)`.
Outputs: Outputs:
Tensor, loss float tensor. Tensor, loss float tensor.
......
...@@ -98,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, ...@@ -98,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
return mnist_ds return mnist_ds
def test_svi_cave(): def test_svi_cvae():
# define the encoder and decoder # define the encoder and decoder
encoder = Encoder(num_classes=10) encoder = Encoder(num_classes=10)
decoder = Decoder() decoder = Decoder()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册