提交 dda9c355 编写于 作者: G gongweibao

fix

上级 45efc1dd
...@@ -57,11 +57,12 @@ A layer for sampling id from multinomial distribution from the ...@@ -57,11 +57,12 @@ A layer for sampling id from multinomial distribution from the
.SetDefault(0.0f); .SetDefault(0.0f);
AddAttr<float>("max", "Maximun value of random. (float, default 1.0).") AddAttr<float>("max", "Maximun value of random. (float, default 1.0).")
.SetDefault(1.0f); .SetDefault(1.0f);
AddAttr<int>("seed", AddAttr<int>(
"Random seed used for the random number engine. " "seed",
"0 means use a seed generated by the system." "Random seed used for the random number engine. "
"Note that if seed is not 0, this operator will always " "0 means use a seed generated by the system."
"generate the same random numbers every time. [default 0].") "Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. (int, default 0).")
.SetDefault(0); .SetDefault(0);
} }
}; };
......
...@@ -6263,7 +6263,7 @@ def gaussian_random(shape, ...@@ -6263,7 +6263,7 @@ def gaussian_random(shape,
return out return out
def sampling_id(x, min=0.0, max=1.0, seed=0): def sampling_id(x, min=0.0, max=1.0, seed=0, dtype='float32'):
""" """
SamplingId Operator. SamplingId Operator.
...@@ -6276,6 +6276,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0): ...@@ -6276,6 +6276,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0):
max (Float): Maximun value of random. max (Float): Maximun value of random.
seed (Float): random seed used for the random number engine.0 means use a seed generated by the system. seed (Float): random seed used for the random number engine.0 means use a seed generated by the system.
Note that if seed is not 0, this operator will always generate the same random numbers every time. Note that if seed is not 0, this operator will always generate the same random numbers every time.
dtype(np.dtype|core.VarDesc.VarType|str): The type of output data : float32, float_16, int etc
Returns: Returns:
out (Variable): Output of this operator. out (Variable): Output of this operator.
...@@ -6283,7 +6284,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0): ...@@ -6283,7 +6284,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0):
""" """
helper = LayerHelper('sampling_id', **locals()) helper = LayerHelper('sampling_id', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype('x')) out = helper.create_tmp_variable(dtype)
helper.append_op( helper.append_op(
type='sampling_id', type='sampling_id',
inputs={'X': x}, inputs={'X': x},
...@@ -6314,7 +6315,7 @@ def gaussian_random_batch_size_like(input, ...@@ -6314,7 +6315,7 @@ def gaussian_random_batch_size_like(input,
mean (Float): The mean (or center) of the gaussian distribution. mean (Float): The mean (or center) of the gaussian distribution.
std (Float): The standard deviation (std, or spread) of the gaussian distribution. std (Float): The standard deviation (std, or spread) of the gaussian distribution.
seed (Int): Random seed of generator.0 means use system wide seed._note that if seed is not 0, this operator will always generate the same random numbers every time. seed (Int): Random seed of generator.0 means use system wide seed._note that if seed is not 0, this operator will always generate the same random numbers every time.
dtype(np.dtype|core.VarDesc.VarType|str): Output data type. dtype(np.dtype|core.VarDesc.VarType|str): The type of output data : float32, float_16, int etc
Returns: Returns:
out (Variable): Output of this operator out (Variable): Output of this operator
......
...@@ -614,7 +614,11 @@ class TestBook(unittest.TestCase): ...@@ -614,7 +614,11 @@ class TestBook(unittest.TestCase):
def test_sampling_id(self): def test_sampling_id(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.data(name="X", shape=[13, 11], dtype='float32') x = layers.data(
name="X",
shape=[13, 11],
dtype='float32',
append_batch_size=False)
out = layers.sampling_id(x) out = layers.sampling_id(x)
self.assertIsNotNone(out) self.assertIsNotNone(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册