提交 10f381d6 编写于 作者: P peixu_ren

Modify the name of parameters in uniform

上级 f3fd7a55
...@@ -92,55 +92,55 @@ def normal(shape, mean, stddev, seed=0): ...@@ -92,55 +92,55 @@ def normal(shape, mean, stddev, seed=0):
value = random_normal * stddev + mean value = random_normal * stddev + mean
return value return value
def uniform(shape, a, b, seed=0, dtype=mstype.float32): def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32):
""" """
Generates random numbers according to the Uniform random number distribution. Generates random numbers according to the Uniform random number distribution.
Note: Note:
The number in tensor a should be strictly less than b at any position after broadcasting. The number in tensor minval should be strictly less than maxval at any position after broadcasting.
Args: Args:
shape (tuple): The shape of random tensor to be generated. shape (tuple): The shape of random tensor to be generated.
a (Tensor): The a distribution parameter. minval (Tensor): The a distribution parameter.
It defines the minimum possibly generated value. With int32 or float32 data type. It defines the minimum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
b (Tensor): The b distribution parameter. maxval (Tensor): The b distribution parameter.
It defines the maximum possibly generated value. With int32 or float32 data type. It defines the maximum possibly generated value. With int32 or float32 data type.
If dtype is int32, only one number is allowed. If dtype is int32, only one number is allowed.
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0. Must be non-negative. Default: 0.
Returns: Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b. Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval.
The dtype is designated as the input `dtype`. The dtype is designated as the input `dtype`.
Examples: Examples:
>>> For discrete uniform distribution, only one number is allowed for both a and b: >>> For discrete uniform distribution, only one number is allowed for both minval and maxval:
>>> shape = (4, 2) >>> shape = (4, 2)
>>> a = Tensor(1, mstype.int32) >>> minval = Tensor(1, mstype.int32)
>>> b = Tensor(2, mstype.int32) >>> maxval = Tensor(2, mstype.int32)
>>> output = C.uniform(shape, a, b, seed=5) >>> output = C.uniform(shape, minval, maxval, seed=5)
>>> >>>
>>> For continuous uniform distribution, a and b can be multi-dimentional: >>> For continuous uniform distribution, minval and maxval can be multi-dimentional:
>>> shape = (4, 2) >>> shape = (4, 2)
>>> a = Tensor([1.0, 2.0], mstype.float32) >>> minval = Tensor([1.0, 2.0], mstype.float32)
>>> b = Tensor([4.0, 5.0], mstype.float32) >>> maxval = Tensor([4.0, 5.0], mstype.float32)
>>> output = C.uniform(shape, a, b, seed=5) >>> output = C.uniform(shape, minval, maxval, seed=5)
""" """
a_dtype = F.dtype(a) minval_dtype = F.dtype(minval)
b_dtype = F.dtype(b) maxval_dtype = F.dtype(maxval)
const_utils.check_tensors_dtype_same(a_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(b_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
const_utils.check_non_negative("seed", seed, "uniform") const_utils.check_non_negative("seed", seed, "uniform")
seed1 = get_seed() seed1 = get_seed()
seed2 = seed seed2 = seed
if const_utils.is_same_type(dtype, mstype.int32): if const_utils.is_same_type(dtype, mstype.int32):
random_uniform = P.UniformInt(seed1, seed2) random_uniform = P.UniformInt(seed1, seed2)
value = random_uniform(shape, a, b) value = random_uniform(shape, minval, maxval)
else: else:
uniform_real = P.UniformReal(seed1, seed2) uniform_real = P.UniformReal(seed1, seed2)
random_uniform = uniform_real(shape) random_uniform = uniform_real(shape)
value = random_uniform * (b - a) + a value = random_uniform * (maxval - minval) + minval
return value return value
def gamma(shape, alpha, beta, seed=0): def gamma(shape, alpha, beta, seed=0):
......
...@@ -224,14 +224,14 @@ class Poisson(PrimitiveWithInfer): ...@@ -224,14 +224,14 @@ class Poisson(PrimitiveWithInfer):
class UniformInt(PrimitiveWithInfer): class UniformInt(PrimitiveWithInfer):
r""" r"""
Produces random integer values i, uniformly distributed on the closed interval [a, b), that is, Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,
distributed according to the discrete probability function: distributed according to the discrete probability function:
.. math:: .. math::
\text{P}(i|a,b) = \frac{1}{b-a+1}, \text{P}(i|a,b) = \frac{1}{b-a+1},
Note: Note:
The number in tensor a should be strictly less than b at any position after broadcasting. The number in tensor minval should be strictly less than maxval at any position after broadcasting.
Args: Args:
seed (int): Random seed. Must be non-negative. Default: 0. seed (int): Random seed. Must be non-negative. Default: 0.
...@@ -239,9 +239,9 @@ class UniformInt(PrimitiveWithInfer): ...@@ -239,9 +239,9 @@ class UniformInt(PrimitiveWithInfer):
Inputs: Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed. - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
- **a** (Tensor) - The a distribution parameter. - **minval** (Tensor) - The a distribution parameter.
It defines the minimum possibly generated value. With int32 data type. Only one number is supported. It defines the minimum possibly generated value. With int32 data type. Only one number is supported.
- **b** (Tensor) - The b distribution parameter. - **maxval** (Tensor) - The b distribution parameter.
It defines the maximum possibly generated value. With int32 data type. Only one number is supported. It defines the maximum possibly generated value. With int32 data type. Only one number is supported.
Outputs: Outputs:
...@@ -249,32 +249,32 @@ class UniformInt(PrimitiveWithInfer): ...@@ -249,32 +249,32 @@ class UniformInt(PrimitiveWithInfer):
Examples: Examples:
>>> shape = (4, 16) >>> shape = (4, 16)
>>> a = Tensor(1, mstype.int32) >>> minval = Tensor(1, mstype.int32)
>>> b = Tensor(5, mstype.int32) >>> maxval = Tensor(5, mstype.int32)
>>> uniform_int = P.UniformInt(seed=10) >>> uniform_int = P.UniformInt(seed=10)
>>> output = uniform_int(shape, a, b) >>> output = uniform_int(shape, minval, maxval)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, seed=0, seed2=0): def __init__(self, seed=0, seed2=0):
"""Init UniformInt""" """Init UniformInt"""
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output']) self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
validator.check_integer("seed", seed, 0, Rel.GE, self.name) validator.check_integer("seed", seed, 0, Rel.GE, self.name)
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
def __infer__(self, shape, a, b): def __infer__(self, shape, minval, maxval):
shape_v = shape["value"] shape_v = shape["value"]
if shape_v is None: if shape_v is None:
raise ValueError(f"For {self.name}, shape must be const.") raise ValueError(f"For {self.name}, shape must be const.")
validator.check_value_type("shape", shape_v, [tuple], self.name) validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v): for i, shape_i in enumerate(shape_v):
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name) validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name) validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name) validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
a_shape = a['shape'] minval_shape = minval['shape']
b_shape = b['shape'] maxval_shape = maxval['shape']
validator.check("dim of a", len(a_shape), '0(scalar)', 0, Rel.EQ, self.name) validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
validator.check("dim of b", len(b_shape), '0(scalar)', 0, Rel.EQ, self.name) validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
out = { out = {
'shape': shape_v, 'shape': shape_v,
'dtype': mstype.int32, 'dtype': mstype.int32,
......
...@@ -28,28 +28,28 @@ class Net(nn.Cell): ...@@ -28,28 +28,28 @@ class Net(nn.Cell):
self.uniformint = P.UniformInt(seed=seed) self.uniformint = P.UniformInt(seed=seed)
self.shape = shape self.shape = shape
def construct(self, a, b): def construct(self, minval, maxval):
return self.uniformint(self.shape, a, b) return self.uniformint(self.shape, minval, maxval)
def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1 minval = 1
b = 5 maxval = 5
net = Net(shape, seed=seed) net = Net(shape, seed=seed)
ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32) tminval, tmaxval = Tensor(minval, mstype.int32), Tensor(maxval, mstype.int32)
output = net(ta, tb) output = net(tminval, tmaxval)
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)
def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 2, 1) shape = (3, 2, 1)
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32) minval = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
b = np.array([10]).astype(np.int32) maxval = np.array([10]).astype(np.int32)
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a), Tensor(b) tminval, tmaxval = Tensor(minval), Tensor(maxval)
output = net(ta, tb) output = net(tminval, tmaxval)
print(output.asnumpy()) print(output.asnumpy())
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)
...@@ -29,28 +29,28 @@ class Net(nn.Cell): ...@@ -29,28 +29,28 @@ class Net(nn.Cell):
self.shape = shape self.shape = shape
self.seed = seed self.seed = seed
def construct(self, a, b): def construct(self, minval, maxval):
C.set_seed(20) C.set_seed(20)
return C.uniform(self.shape, a, b, self.seed) return C.uniform(self.shape, minval, maxval, self.seed)
def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 1.0 minval = 1.0
b = 6.0 maxval = 6.0
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(ta, tb) output = net(tminval, tmaxval)
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)
def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 1, 2) shape = (3, 1, 2)
a = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) minval = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
b = np.array([1.0]).astype(np.float32) maxval = np.array([1.0]).astype(np.float32)
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(ta, tb) output = net(tminval, tmaxval)
assert output.shape == (3, 2, 2) assert output.shape == (3, 2, 2)
...@@ -27,17 +27,17 @@ class Net(nn.Cell): ...@@ -27,17 +27,17 @@ class Net(nn.Cell):
self.uniformreal = P.UniformReal(seed=seed) self.uniformreal = P.UniformReal(seed=seed)
self.shape = shape self.shape = shape
def construct(self, a, b): def construct(self, minval, maxval):
return self.uniformreal(self.shape, a, b) return self.uniformreal(self.shape, minval, maxval)
def test_net_1D(): def test_net_1D():
seed = 10 seed = 10
shape = (3, 2, 4) shape = (3, 2, 4)
a = 0.0 minval = 0.0
b = 1.0 maxval = 1.0
net = Net(shape, seed) net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32) tminval, tmaxval = Tensor(minval, mstype.float32), Tensor(maxval, mstype.float32)
output = net(ta, tb) output = net(tminval, tmaxval)
print(output.asnumpy()) print(output.asnumpy())
assert output.shape == (3, 2, 4) assert output.shape == (3, 2, 4)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册