提交 dda9c355 编写于 作者: G gongweibao

fix

上级 45efc1dd
......@@ -57,11 +57,12 @@ A layer for sampling id from multinomial distribution from the
.SetDefault(0.0f);
AddAttr<float>("max", "Maximun value of random. (float, default 1.0).")
.SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed used for the random number engine. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
AddAttr<int>(
"seed",
"Random seed used for the random number engine. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. (int, default 0).")
.SetDefault(0);
}
};
......
......@@ -6263,7 +6263,7 @@ def gaussian_random(shape,
return out
def sampling_id(x, min=0.0, max=1.0, seed=0):
def sampling_id(x, min=0.0, max=1.0, seed=0, dtype='float32'):
"""
SamplingId Operator.
......@@ -6276,6 +6276,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0):
max (Float): Maximun value of random.
seed (Float): random seed used for the random number engine.0 means use a seed generated by the system.
Note that if seed is not 0, this operator will always generate the same random numbers every time.
dtype(np.dtype|core.VarDesc.VarType|str): The type of output data : float32, float_16, int etc
Returns:
out (Variable): Output of this operator.
......@@ -6283,7 +6284,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0):
"""
helper = LayerHelper('sampling_id', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype('x'))
out = helper.create_tmp_variable(dtype)
helper.append_op(
type='sampling_id',
inputs={'X': x},
......@@ -6314,7 +6315,7 @@ def gaussian_random_batch_size_like(input,
mean (Float): The mean (or center) of the gaussian distribution.
std (Float): The standard deviation (std, or spread) of the gaussian distribution.
seed (Int): Random seed of generator.0 means use system wide seed._note that if seed is not 0, this operator will always generate the same random numbers every time.
dtype(np.dtype|core.VarDesc.VarType|str): Output data type.
dtype(np.dtype|core.VarDesc.VarType|str): The type of output data : float32, float_16, int etc
Returns:
out (Variable): Output of this operator
......
......@@ -614,7 +614,11 @@ class TestBook(unittest.TestCase):
def test_sampling_id(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[13, 11], dtype='float32')
x = layers.data(
name="X",
shape=[13, 11],
dtype='float32',
append_batch_size=False)
out = layers.sampling_id(x)
self.assertIsNotNone(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册