Commit efdcd4c2 authored by mindspore-ci-bot, committed by Gitee

!682 fix api comments of some ops

Merge pull request !682 from liuxiao/fix-bug
@@ -893,8 +893,8 @@ ATTR_MAP(TransposeD) = EMPTY_ATTR_MAP;
// DropOutGenMask
INPUT_MAP(DropOutGenMask) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(prob)}};
-ATTR_MAP(DropOutGenMask) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
-                            {"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
+ATTR_MAP(DropOutGenMask) = {{"Seed0", ATTR_DESC(seed, AnyTraits<int64_t>())},
+                            {"Seed1", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
OUTPUT_MAP(DropOutGenMask) = {{0, OUTPUT_DESC(y)}};
// Pack
......
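For reference, a minimal usage sketch of the primitive behind the DropOutGenMask mapping above, assuming the r0.x Python API in which `Seed0` and `Seed1` (the corrected map keys) are the primitive's attributes; the two call inputs follow the INPUT_MAP above (`shape`, `prob`):

    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Assumption: Seed0/Seed1 are forwarded to the GE attributes `seed`/`seed2`
    # via the ATTR_MAP fixed above
    gen_mask = P.DropoutGenMask(Seed0=10, Seed1=10)
    keep_prob = Tensor(0.5, mindspore.float32)
    mask = gen_mask((2, 4, 2, 2), keep_prob)  # mask tensor for the given shape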
@@ -397,9 +397,8 @@ class LayerNorm(Cell):
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
-    normalized_shape (Union[tuple[int], list[int]]): The normalization is performed over axes
-        `begin_norm_axis ... R - 1`, and centering and scaling parameters are calculated over
-        `begin_params_axis ... R - 1`.
+    normalized_shape (Union[tuple[int], list[int]]): The normalization is performed over axes
+        `begin_norm_axis ... R - 1`.
    begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
        `begin_norm_axis: rank(inputs)`; the value should be in [-1, rank(input)). Default: -1.
    begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
......
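A short usage sketch for the corrected docstring, assuming the constructor signature documented above (`normalized_shape`, `begin_norm_axis`, `begin_params_axis`):

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
    # normalize over axes 1..3; gamma and beta are created with shape (5, 10, 10)
    layer_norm = nn.LayerNorm((5, 10, 10), begin_norm_axis=1, begin_params_axis=1)
    y = layer_norm(x)  # same shape as x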
@@ -126,7 +126,8 @@ class Adam(Optimizer):
Args:
    params (list[Parameter]): A list of parameters, which will be updated. Each element in `params`
        should be of class mindspore.Parameter.
-    learning_rate (float): The learning rate.
+    learning_rate (Union[float, Tensor, Iterable]): The learning rate.
+        Iterable type is used for the dynamic learning rate.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
......
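To illustrate the widened `learning_rate` type, a hedged sketch of the two forms (the Iterable form is per the note added above; `nn.Dense` stands in for any network):

    import mindspore.nn as nn

    net = nn.Dense(3, 4)

    # fixed learning rate
    optim = nn.Adam(params=net.trainable_params(), learning_rate=0.1)

    # dynamic learning rate: an Iterable supplies one value per step
    optim_dynamic = nn.Adam(params=net.trainable_params(),
                            learning_rate=[0.1, 0.05, 0.01])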
@@ -490,6 +490,15 @@ class FusedBatchNorm(Primitive):
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
+    Examples:
+        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
+        >>> scale = Tensor(np.ones([64]), mindspore.float32)
+        >>> bias = Tensor(np.ones([64]), mindspore.float32)
+        >>> mean = Tensor(np.ones([64]), mindspore.float32)
+        >>> variance = Tensor(np.ones([64]), mindspore.float32)
+        >>> op = P.FusedBatchNorm()
+        >>> output = op(input_x, scale, bias, mean, variance)
"""
@prim_attr_register
@@ -733,10 +742,17 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
    - **weight** (Tensor) - If the kernel size is :math:`(K_1, K_2)`, the shape is
-        :math:`(\text{channel_multiplier}, C_{in}, K_1, K_2)`.
+        :math:`(K, C_{in}, K_1, K_2)`, where `K` must be 1.
Outputs:
Tensor of shape :math:`(N, C_{in} * \text{channel_multiplier}, H_{out}, W_{out})`.
+    Examples:
+        >>> input = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
+        >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32)
+        >>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3))
+        >>> output = depthwise_conv2d(input, weight)
+        >>> assert output.shape() == (10, 96, 30, 30)
"""
@prim_attr_register
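The shape asserted in the new example follows from the usual valid-padding arithmetic; a quick check:

    # channels: C_in * channel_multiplier = 32 * 3 = 96 (assumption: default
    # valid padding and stride 1, so each spatial size is in - k + 1)
    c_out = 32 * 3
    h_out = w_out = 32 - 3 + 1
    assert (10, c_out, h_out, w_out) == (10, 96, 30, 30)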
@@ -1655,6 +1671,15 @@ class LayerNorm(Primitive):
The shape is :math:`(N, C)`.
- **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`.
+    Examples:
+        >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
+        >>> gamma = Tensor(np.ones([3]), mindspore.float32)
+        >>> beta = Tensor(np.ones([3]), mindspore.float32)
+        >>> layer_norm = P.LayerNorm()
+        >>> output = layer_norm(input_x, gamma, beta)
+        ([[-0.22474492, 1., 2.2247488], [-0.22474492, 1., 2.2247488]],
+         [[2.], [2.]], [[0.6666667], [0.6666667]])
"""
@prim_attr_register
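The doctest values above can be reproduced with plain numpy; a sketch (not the MindSpore kernel), assuming normalization over the last axis with a small epsilon:

    import numpy as np

    x = np.array([[1, 2, 3], [1, 2, 3]], np.float32)
    mean = x.mean(axis=-1, keepdims=True)   # [[2.], [2.]]
    var = x.var(axis=-1, keepdims=True)     # [[0.6666667], [0.6666667]]
    gamma = beta = np.ones(3, np.float32)
    y = (x - mean) / np.sqrt(var + 1e-7) * gamma + beta
    # y ~ [[-0.2247, 1., 2.2247], [-0.2247, 1., 2.2247]]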
@@ -2312,11 +2337,13 @@ class Adam(PrimitiveWithInfer):
Inputs:
- **var** (Tensor) - Weights to be updated.
-    - **m** (Tensor) - The 1st moment vector in the updating formula.
+    - **m** (Tensor) - The 1st moment vector in the updating formula. Has the same type as `var`.
     - **v** (Tensor) - The 2nd moment vector in the updating formula.
+        Mean square gradients, has the same type as `var`.
- **beta1_power** (float) - :math:`beta_1^t` in the updating formula.
- **beta2_power** (float) - :math:`beta_2^t` in the updating formula.
-    - **lr** (float) - :math:`l` in the updating formula.
+    - **lr** (Union[float, Tensor, Iterable]) - :math:`l` in the updating formula.
+        Iterable type is used for the dynamic learning rate.
- **beta1** (float) - The exponential decay rate for the 1st moment estimates.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
- **epsilon** (float) - Term added to the denominator to improve numerical stability.
@@ -2328,6 +2355,9 @@ class Adam(PrimitiveWithInfer):
- **var** (Tensor) - The same shape and data type as `var`.
- **m** (Tensor) - The same shape and data type as `m`.
- **v** (Tensor) - The same shape and data type as `v`.
+    Examples:
+        Please refer to the usage in nn.Adam.
"""
@prim_attr_register
......
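For readers of the input list above, a numpy sketch of the textbook Adam step that those inputs parameterize (an illustration, not the fused kernel):

    import numpy as np

    def adam_step(var, m, v, grad, beta1_power, beta2_power,
                  lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad * grad
        lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
        var = var - lr_t * m / (np.sqrt(v) + eps)
        return var, m, v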
@@ -793,12 +793,12 @@ test_case_nn_ops = [
'desc_bprop': [[5, 5]]}),
('DepthwiseConv2dNative_1', {
'block': P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2),
-        'desc_inputs': [[10, 32, 32, 32], [3, 32, 3, 3]],
-        'desc_bprop': [[10, 30, 16, 16]]}),
+        'desc_inputs': [[10, 32, 32, 32], [1, 32, 3, 3]],
+        'desc_bprop': [[10, 32, 16, 16]]}),
('DepthwiseConv2dNative_2', {
'block': P.DepthwiseConv2dNative(1, (3, 3), pad_mode="same", pad=0, stride=1),
'desc_inputs': [[2592, 2048, 4, 4], [1, 2048, 3, 3]],
-        'desc_bprop': [[2592, 2048, 2, 2]]}),
+        'desc_bprop': [[2592, 2048, 4, 4]]}),
('SigmoidCrossEntropyWithLogits', {
'block': P.SigmoidCrossEntropyWithLogits(),
'desc_inputs': [[128, 10], [128, 10]],
......
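The corrected `desc_bprop` spatial sizes match the standard convolution output-size formula; a quick check under that assumption:

    def out_size(in_size, k, pad, stride):
        return (in_size + 2 * pad - k) // stride + 1

    assert out_size(32, 3, 1, 2) == 16  # DepthwiseConv2dNative_1: pad_mode="pad", pad=1, stride=2
    assert out_size(4, 3, 1, 1) == 4    # DepthwiseConv2dNative_2: "same" padding keeps 4x4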