提交 b958f15d 编写于 作者: L lihongkang

fix bugs

上级 0c316e52
......@@ -65,18 +65,22 @@ class Dropout(Cell):
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
Raises:
ValueError: If `keep_prob` is not in range (0, 1).
ValueError: If `keep_prob` is not in range (0, 1].
Inputs:
- **input** (Tensor) - An N-D Tensor.
- **input** (Tensor) - The input tensor.
Outputs:
Tensor, output tensor with the same shape as the input.
Examples:
>>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> net = nn.Dropout(keep_prob=0.8)
>>> net(x)
[[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]]]
"""
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32):
......@@ -84,6 +88,7 @@ class Dropout(Cell):
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
self.seed0 = seed0
self.seed1 = seed1
......@@ -107,8 +112,7 @@ class Dropout(Cell):
return x
shape = self.get_shape(x)
dtype = P.DType()(x)
keep_prob = self.cast(self.keep_prob, dtype)
keep_prob = self.cast(self.keep_prob, mstype.float32)
output = self.dropout_gen_mask(shape, keep_prob)
return self.dropout_do_mask(x, output, keep_prob)
......
......@@ -585,9 +585,18 @@ class GroupNorm(Cell):
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> goup_norm_op = nn.GroupNorm(16, 64)
>>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
>>> group_norm_op = nn.GroupNorm(2, 2)
>>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
>>> group_norm_op(x)
[[[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]]]
"""
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
......
......@@ -1360,7 +1360,7 @@ class Tile(PrimitiveWithInfer):
- **multiples** (tuple[int]) - The input tuple is constructed by multiple
integers, i.e., :math:`(y_1, y_2, ..., y_S)`. The length of `multiples`
can't be smaller than the length of shape in `input_x`.
can't be smaller than the length of shape in `input_x`. Only constant value is allowed.
Outputs:
Tensor, has the same type as the `input_x`.
......@@ -1400,7 +1400,7 @@ class Tile(PrimitiveWithInfer):
def __infer__(self, x, multiples):
multiples_v = multiples['value']
x_shp = x['shape']
validator.check_value_type("shape", multiples_v, [tuple], self.name)
validator.check_value_type("multiples", multiples_v, [tuple], self.name)
for i, multiple in enumerate(multiples_v):
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name)
......
......@@ -1382,10 +1382,10 @@ class Exp(PrimitiveWithInfer):
Returns exponential of a tensor element-wise.
Inputs:
- **input_x** (Tensor) - The input tensor.
- **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
Outputs:
Tensor, has the same shape as the `input_x`.
Tensor, has the same shape and dtype as the `input_x`.
Examples:
>>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
......@@ -1452,7 +1452,7 @@ class HistogramFixedWidth(PrimitiveWithInfer):
width and determined by the arguments range and nbins.
Args:
dtype (string): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
dtype (str): An optional attribute. Must be one of the following types: "int32", "int64". Default: "int32".
nbins (int): The number of histogram bins, the type is a positive integer.
Inputs:
......
......@@ -264,6 +264,9 @@ class IOU(PrimitiveWithInfer):
>>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
>>> iou(anchor_boxes, gt_boxes)
[[0.0, 65504, 65504],
[0.0, 0.0, 0.0],
[0.22253, 0.0, 0.0]]
"""
@prim_attr_register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册