未验证 提交 87694b12 编写于 作者: S silingtong123 提交者: GitHub

add 'seed' arguemnt of randint API (#23809) (#24094)

test=release/2.0-beta, cherry-pick PR #23809: add 'seed' arguemnt of randint API
上级 193d1430
......@@ -42,11 +42,15 @@ class CPURandintKernel : public framework::OpKernel<T> {
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
T* data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = out->numel();
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dist(ctx.Attr<int>("low"),
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_int_distribution<T> dist(ctx.Attr<int>("low"),
ctx.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(gen);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
}
};
......@@ -153,6 +157,12 @@ uniform distribution. The random result is in set [low, high).
"The upper bound on the range of random values to generate.");
AddAttr<int>("dtype", "Output tensor data type. [Default INT64].")
.SetDefault(framework::proto::VarType::INT64);
AddAttr<int>("seed",
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
.SetDefault(0);
}
};
......
......@@ -19,25 +19,6 @@
namespace paddle {
namespace operators {
template <typename T>
struct UniformIntGenerator {
T low_, high_;
__host__ __device__ UniformIntGenerator(T low, T high)
: low_(low), high_(high) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(0);
thrust::uniform_int_distribution<T> dist(low_, high_);
rng.discard(n);
T out = dist(rng);
return out;
}
};
// Use std::uniform_int_distribution and thrust::uniform_int_distribution(thrust
// is a std library in CUDA) to
// implement randint.
template <typename T>
class GPURandintKernel : public framework::OpKernel<T> {
public:
......@@ -54,17 +35,34 @@ class GPURandintKernel : public framework::OpKernel<T> {
}
}
platform::CPUPlace cpu;
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
auto* out = context.Output<framework::LoDTensor>("Out");
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
T* data = out->mutable_data<T>(context.GetPlace());
T low = static_cast<T>(context.Attr<int>("low"));
T high = static_cast<T>(context.Attr<int>("high")) - 1;
framework::LoDTensor tensor;
tensor.Resize(out->dims());
tensor.mutable_data(cpu, dtype);
T* data = tensor.mutable_data<T>(cpu);
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = out->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformIntGenerator<T>(low, high));
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
std::random_device rd;
seed = rd();
}
engine.seed(seed);
std::uniform_int_distribution<> dist(context.Attr<int>("low"),
context.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
if (platform::is_gpu_place(context.GetPlace())) {
// Copy tensor to out
framework::TensorCopy(tensor, context.GetPlace(), out);
}
}
};
......
......@@ -26,7 +26,7 @@ import paddle
def output_hist(out):
hist, _ = np.histogram(out, range=(-5, 10))
hist, _ = np.histogram(out, range=(-10, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
......@@ -41,7 +41,7 @@ class TestRandintOp(OpTest):
self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"shape": [10000, 784], "low": -5, "high": 10}
self.attrs = {"shape": [10000, 784], "low": -10, "high": 10, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
......@@ -51,7 +51,7 @@ class TestRandintOp(OpTest):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
hist, prob, rtol=0, atol=0.001), "hist: " + str(hist))
class TestRandintOpError(unittest.TestCase):
......@@ -90,7 +90,7 @@ class TestRandintOp_attr_tensorlist(OpTest):
self.outputs = {"Out": np.zeros((10000, 784)).astype("int32")}
def init_attrs(self):
self.attrs = {"low": -5, "high": 10}
self.attrs = {"low": -10, "high": 10, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
......@@ -100,7 +100,7 @@ class TestRandintOp_attr_tensorlist(OpTest):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
hist, prob, rtol=0, atol=0.001), "hist: " + str(hist))
class TestRandint_attr_tensor(OpTest):
......@@ -111,7 +111,7 @@ class TestRandint_attr_tensor(OpTest):
self.outputs = {"Out": np.zeros((10000, 784)).astype("int64")}
def init_attrs(self):
self.attrs = {"low": -5, "high": 10}
self.attrs = {"low": -10, "high": 10, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
......@@ -121,7 +121,7 @@ class TestRandint_attr_tensor(OpTest):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
hist, prob, rtol=0, atol=0.001), "hist: " + str(hist))
# Test python API
......
......@@ -41,6 +41,7 @@ def randint(low,
dtype=None,
device=None,
stop_gradient=False,
seed=0,
name=None):
"""
This function returns a Tensor filled with random integers from the "discrete uniform" distribution of the
......@@ -66,6 +67,10 @@ def randint(low,
on the GPU or CPU.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
default value is False.
seed (int, optional): Random seed used for permute samples. If seed is
equal to 0, it means use a seed generated by the system. Note that
if seed is not 0, this operator will always generate the same random
permutation every time. Default: 0.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -162,6 +167,7 @@ def randint(low,
low = 0
attrs['low'] = low
attrs['high'] = high
attrs['seed'] = seed
if (low >= high):
raise ValueError(
"randint's low must less then high, but received low = {0}, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册