Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0a9f9f93
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a9f9f93
编写于
9月 04, 2020
作者:
Y
yaoxuefeng
提交者:
GitHub
9月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cuda generator (#26786) (#27014)
上级
09ede3b4
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
523 addition
and
18 deletion
+523
-18
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/generator.cc
paddle/fluid/framework/generator.cc
+53
-0
paddle/fluid/framework/generator.h
paddle/fluid/framework/generator.h
+22
-0
paddle/fluid/operators/bernoulli_op.cu
paddle/fluid/operators/bernoulli_op.cu
+0
-1
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+47
-0
paddle/fluid/operators/gaussian_random_op.cu
paddle/fluid/operators/gaussian_random_op.cu
+53
-7
paddle/fluid/operators/randint_op.cu
paddle/fluid/operators/randint_op.cu
+10
-1
paddle/fluid/operators/truncated_gaussian_random_op.cu
paddle/fluid/operators/truncated_gaussian_random_op.cu
+53
-0
paddle/fluid/operators/uniform_random_op.cu
paddle/fluid/operators/uniform_random_op.cu
+56
-5
paddle/fluid/pybind/generator_py.cc
paddle/fluid/pybind/generator_py.cc
+3
-2
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
+163
-0
python/paddle/framework/random.py
python/paddle/framework/random.py
+60
-1
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
0a9f9f93
...
...
@@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib
cc_library
(
save_load_util SRCS save_load_util DEPS tensor scope layer
)
cc_test
(
save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer
)
cc_library
(
generator SRCS generator.cc
)
cc_library
(
generator SRCS generator.cc
DEPS enforce place
)
# Get the current working branch
execute_process
(
...
...
paddle/fluid/framework/generator.cc
浏览文件 @
0a9f9f93
...
...
@@ -21,10 +21,46 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
const
std
::
shared_ptr
<
Generator
>&
GetDefaultCUDAGenerator
(
int64_t
device_id
)
{
#ifdef PADDLE_WITH_CUDA
static
int64_t
num_cuda_devices
=
-
1
;
static
std
::
once_flag
num_devices_init_flag
;
static
std
::
deque
<
std
::
once_flag
>
cuda_device_flags
;
static
std
::
vector
<
std
::
shared_ptr
<
Generator
>>
default_cuda_generators
;
std
::
call_once
(
num_devices_init_flag
,
[]()
{
num_cuda_devices
=
paddle
::
platform
::
GetCUDADeviceCount
();
cuda_device_flags
.
resize
(
num_cuda_devices
);
default_cuda_generators
.
resize
(
num_cuda_devices
);
});
if
(
device_id
<
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"cuda device id shoule be greater than 0"
));
}
std
::
call_once
(
cuda_device_flags
[
device_id
],
[
device_id
]()
{
default_cuda_generators
[
device_id
]
=
std
::
make_shared
<
Generator
>
(
GetRandomSeed
(),
device_id
);
VLOG
(
4
)
<<
"initial seed: "
<<
default_cuda_generators
[
device_id
]
->
GetCurrentSeed
();
});
return
default_cuda_generators
[
device_id
];
#else
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"getDefaultCUDAGenerator only support in CUDA place"
));
#endif
}
const
std
::
shared_ptr
<
Generator
>&
DefaultCPUGenerator
()
{
static
auto
default_cpu_generator
=
std
::
make_shared
<
Generator
>
(
GetRandomSeed
());
...
...
@@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
void
Generator
::
SetCurrentSeed
(
uint64_t
seed
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mu_
);
this
->
state_
.
current_seed
=
seed
;
this
->
state_
.
thread_offset
=
0
;
std
::
seed_seq
seq
({
seed
});
this
->
engine_
->
seed
(
seq
);
}
...
...
@@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
return
(
*
engine
)();
}
std
::
pair
<
uint64_t
,
uint64_t
>
Generator
::
IncrementOffset
(
uint64_t
increament_offset
)
{
uint64_t
cur_offset
=
this
->
state_
.
thread_offset
;
#ifdef PADDLE_WITH_CUDA
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mu_
);
this
->
state_
.
thread_offset
+=
increament_offset
;
#else
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Increment Offset only support in CUDA place"
));
#endif
return
std
::
make_pair
(
static_cast
<
int
>
(
this
->
state_
.
current_seed
),
cur_offset
);
}
void
Generator
::
SetIsInitPy
(
bool
is_init_py
)
{
this
->
is_init_py_
=
is_init_py
;
VLOG
(
4
)
<<
"SetIsInitPy:"
<<
this
->
is_init_py_
;
...
...
paddle/fluid/framework/generator.h
浏览文件 @
0a9f9f93
...
...
@@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
struct
GeneratorState
{
int64_t
device
=
-
1
;
uint64_t
current_seed
=
34342423252
;
uint64_t
thread_offset
=
0
;
std
::
mt19937_64
cpu_engine
;
};
...
...
@@ -49,6 +50,7 @@ struct Generator {
this
->
state_
.
cpu_engine
=
*
engine
;
this
->
state_
.
device
=
-
1
;
this
->
state_
.
current_seed
=
seed
;
this
->
state_
.
thread_offset
=
0
;
this
->
engine_
=
engine
;
VLOG
(
4
)
<<
"initial seed: "
<<
this
->
state_
.
current_seed
<<
", cpu engine: "
<<
&
this
->
state_
.
cpu_engine
;
...
...
@@ -59,11 +61,25 @@ struct Generator {
this
->
state_
.
cpu_engine
=
*
engine
;
this
->
state_
.
device
=
-
1
;
this
->
state_
.
current_seed
=
seed
;
this
->
state_
.
thread_offset
=
0
;
this
->
engine_
=
engine
;
VLOG
(
4
)
<<
"initial seed: "
<<
this
->
state_
.
current_seed
<<
", cpu engine: "
<<
&
this
->
state_
.
cpu_engine
;
this
->
is_init_py_
=
true
;
// TODO(zhiqiu): remove it in future
}
Generator
(
uint64_t
seed
,
uint64_t
device_id
)
{
std
::
seed_seq
seq
({
seed
});
auto
engine
=
std
::
make_shared
<
std
::
mt19937_64
>
(
seq
);
this
->
state_
.
cpu_engine
=
*
engine
;
this
->
state_
.
device
=
device_id
;
this
->
state_
.
current_seed
=
seed
;
this
->
state_
.
thread_offset
=
0
;
this
->
engine_
=
engine
;
VLOG
(
4
)
<<
"initial seed: "
<<
this
->
state_
.
current_seed
<<
", cpu engine: "
<<
&
this
->
state_
.
cpu_engine
;
this
->
is_init_py_
=
false
;
// TODO(zhiqiu): remove it in future
}
Generator
(
const
Generator
&
other
)
=
delete
;
// get random state
...
...
@@ -83,8 +99,11 @@ struct Generator {
uint64_t
Random64
();
std
::
pair
<
uint64_t
,
uint64_t
>
IncrementOffset
(
uint64_t
increament_offset
);
void
SetIsInitPy
(
bool
);
bool
GetIsInitPy
()
const
;
uint64_t
get_device_id
()
{
return
this
->
state_
.
device
;
}
private:
GeneratorState
state_
;
...
...
@@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
std
::
shared_ptr
<
std
::
mt19937_64
>
GetCPURandomEngine
(
uint64_t
);
const
std
::
shared_ptr
<
Generator
>&
GetDefaultCUDAGenerator
(
int64_t
device_id
=
-
1
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/operators/bernoulli_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -16,7 +16,6 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
...
...
paddle/fluid/operators/dropout_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
}
}
template
<
typename
T
,
typename
MaskType
>
__global__
void
RandomGeneratorWithGenerator
(
const
size_t
n
,
uint64_t
seed
,
const
float
dropout_prob
,
const
T
*
src
,
MaskType
*
mask_data
,
T
*
dst
,
bool
is_upscale_in_train
,
uint64_t
increment
)
{
curandStatePhilox4_32_10_t
state
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
step_size
=
0
;
MaskType
mask
;
T
dest
;
for
(;
idx
<
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
src
[
idx
];
if
(
step_size
==
0
)
{
curand_init
(
seed
,
idx
,
increment
,
&
state
);
step_size
=
blockDim
.
x
*
gridDim
.
x
;
}
else
{
curand_init
(
seed
,
idx
,
increment
,
&
state
);
}
if
(
curand_uniform
(
&
state
)
<
dropout_prob
)
{
mask
=
0
;
dest
=
0
;
}
else
{
mask
=
1
;
if
(
is_upscale_in_train
)
{
dest
=
s
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
dest
=
s
;
}
}
mask_data
[
idx
]
=
mask
;
dst
[
idx
]
=
dest
;
}
}
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
...
...
@@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
context
.
Attr
<
bool
>
(
"fix_seed"
)
?
context
.
Attr
<
int
>
(
"seed"
)
:
rnd
();
}
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
())
.
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
(
!
context
.
Attr
<
bool
>
(
"fix_seed"
)))
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
RandomGeneratorWithGenerator
<
T
,
uint8_t
><<<
grid
,
threads
,
0
,
stream
>>>
(
size
,
seed_offset
.
first
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
upscale_in_train
,
seed_offset
.
second
);
return
;
}
RandomGenerator
<
T
,
uint8_t
><<<
grid
,
threads
,
0
,
stream
>>>
(
size
,
seed_data
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
upscale_in_train
);
...
...
paddle/fluid/operators/gaussian_random_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
...
...
@@ -24,15 +25,20 @@ template <typename T>
struct
GaussianGenerator
{
T
mean_
,
std_
;
unsigned
int
seed_
;
unsigned
int
offset_
=
0
;
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
)
{}
__host__
__device__
GaussianGenerator
(
T
mean
,
T
std
,
int
seed
,
int
offset
)
:
mean_
(
mean
),
std_
(
std
),
seed_
(
seed
),
offset_
(
offset
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
normal_distribution
<
T
>
dist
(
mean_
,
std_
);
rng
.
discard
(
n
);
unsigned
int
new_n
=
n
+
offset_
;
rng
.
discard
(
new_n
);
return
dist
(
rng
);
}
};
...
...
@@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
bool
seed_flag
=
false
;
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
seed_flag
=
true
;
}
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
...
...
@@ -56,9 +64,27 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()).
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
seed_flag
)
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
int
offset_step
=
100
;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int
gen_offset
=
offset_step
*
seed_offset
.
second
;
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed_offset
.
first
,
gen_offset
));
}
else
{
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
}
}
};
...
...
@@ -69,17 +95,37 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
bool
seed_flag
=
false
;
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
seed_flag
=
true
;
}
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()).
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
seed_flag
)
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
int
offset_step
=
100
;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int
gen_offset
=
offset_step
*
seed_offset
.
second
;
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed_offset
.
first
,
seed_offset
.
second
));
}
else
{
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
}
}
};
}
// namespace operators
...
...
paddle/fluid/operators/randint_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"
...
...
@@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {
int64_t
size
=
out
->
numel
();
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
/*
std::minstd_rand engine;
if (seed == 0) {
std::random_device rd;
seed = rd();
}
engine.seed(seed);
*/
std
::
uniform_int_distribution
<>
dist
(
context
.
Attr
<
int
>
(
"low"
),
context
.
Attr
<
int
>
(
"high"
)
-
1
);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
data
[
i
]
=
dist
(
engine
);
auto
engine
=
framework
::
GetCPURandomEngine
(
seed
);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
*
engine
);
}
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
// Copy tensor to out
...
...
paddle/fluid/operators/truncated_gaussian_random_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include <limits>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
...
...
@@ -46,6 +47,37 @@ struct TruncatedNormal {
}
};
template
<
typename
T
>
struct
TruncatedNormalOffset
{
T
mean
,
std
;
T
a_normal_cdf
;
T
b_normal_cdf
;
unsigned
int
seed
;
T
numeric_min
;
int
offset_
;
__host__
__device__
TruncatedNormalOffset
(
T
mean
,
T
std
,
T
numeric_min
,
int
seed
,
int
offset
)
:
mean
(
mean
),
std
(
std
),
seed
(
seed
),
numeric_min
(
numeric_min
),
offset_
(
offset
)
{
a_normal_cdf
=
(
1.0
+
erff
(
-
2.0
/
sqrtf
(
2.0
)))
/
2.0
;
b_normal_cdf
=
(
1.0
+
erff
(
2.0
/
sqrtf
(
2.0
)))
/
2.0
;
}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
numeric_min
,
1
);
rng
.
discard
(
n
);
T
value
=
dist
(
rng
);
auto
p
=
a_normal_cdf
+
(
b_normal_cdf
-
a_normal_cdf
)
*
value
;
return
std
::
sqrt
(
2.0
)
*
erfinvf
(
2
*
p
-
1
)
*
std
+
mean
;
}
};
template
<
typename
T
>
class
GPUTruncatedGaussianRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -54,14 +86,35 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
bool
seed_flag
=
false
;
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
seed_flag
=
true
;
}
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()).
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
seed_flag
)
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
int
offset_step
=
100
;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int
gen_offset
=
offset_step
*
seed_offset
.
second
;
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
TruncatedNormalOffset
<
T
>
(
mean
,
std
,
std
::
numeric_limits
<
T
>::
min
(),
seed_offset
.
first
,
seed_offset
.
second
));
}
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
...
...
paddle/fluid/operators/uniform_random_op.cu
浏览文件 @
0a9f9f93
...
...
@@ -51,6 +51,39 @@ struct UniformGenerator {
}
};
template
<
typename
T
>
struct
UniformGeneratorOffset
{
T
min_
,
max_
;
unsigned
int
seed_
;
T
diag_val_
;
unsigned
int
diag_num_
;
unsigned
int
diag_step_
;
int
offset_
;
__host__
__device__
UniformGeneratorOffset
(
T
min
,
T
max
,
int
seed
,
int
diag_num
,
int
diag_step
,
T
diag_val
,
int
offset
)
:
min_
(
min
),
max_
(
max
),
seed_
(
seed
),
diag_num_
(
diag_num
),
diag_step_
(
diag_step
),
diag_val_
(
diag_val
),
offset_
(
offset
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
min_
,
max_
);
rng
.
discard
(
n
+
offset_
);
T
out
=
dist
(
rng
);
unsigned
int
remainder
=
n
%
(
diag_step_
+
1
);
if
(
remainder
==
0
&&
diag_num_
>
n
/
(
diag_step_
+
1
))
{
out
=
diag_val_
;
}
return
out
;
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
...
...
@@ -89,10 +122,11 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
}
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
bool
seed_flag
=
false
;
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
seed_flag
=
true
;
}
T
min
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"min"
));
...
...
@@ -104,10 +138,27 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
T
diag_val
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"diag_val"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGenerator
<
T
>
(
min
,
max
,
seed
,
diag_num
,
diag_step
,
diag_val
));
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()).
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
seed_flag
)
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
int
offset_step
=
100
;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int
gen_offset
=
offset_step
*
seed_offset
.
second
;
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGeneratorOffset
<
T
>
(
min
,
max
,
seed_offset
.
first
,
diag_num
,
diag_step
,
diag_val
,
gen_offset
));
}
else
{
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGenerator
<
T
>
(
min
,
max
,
seed
,
diag_num
,
diag_step
,
diag_val
));
}
}
};
...
...
paddle/fluid/pybind/generator_py.cc
浏览文件 @
0a9f9f93
...
...
@@ -59,6 +59,7 @@ void BindGenerator(py::module* m_ptr) {
.
def_property
(
"_is_init_py"
,
&
framework
::
Generator
::
GetIsInitPy
,
&
framework
::
Generator
::
SetIsInitPy
);
m
.
def
(
"default_cpu_generator"
,
&
framework
::
DefaultCPUGenerator
);
}
// end Generator
}
// end namespace pybind
m
.
def
(
"default_cuda_generator"
,
&
framework
::
GetDefaultCUDAGenerator
);
}
}
// namespace pybind
}
// namespace paddle
python/paddle/__init__.py
浏览文件 @
0a9f9f93
...
...
@@ -217,6 +217,8 @@ from .tensor.search import index_select #DEFINE_ALIAS
from
.tensor.search
import
nonzero
#DEFINE_ALIAS
from
.tensor.search
import
sort
#DEFINE_ALIAS
from
.framework.random
import
manual_seed
#DEFINE_ALIAS
from
.framework.random
import
get_cuda_rng_state
#DEFINE_ALIAS
from
.framework.random
import
set_cuda_rng_state
#DEFINE_ALIAS
from
.framework
import
Variable
#DEFINE_ALIAS
from
.framework
import
ParamAttr
#DEFINE_ALIAS
from
.framework
import
create_global_var
#DEFINE_ALIAS
...
...
python/paddle/fluid/tests/unittests/test_cuda_random_seed.py
0 → 100644
浏览文件 @
0a9f9f93
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from
__future__
import
print_function
import
os
import
unittest
import
paddle.fluid.generator
as
generator
import
time
# temp for debug
import
paddle.fluid
as
fluid
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
class
TestGeneratorSeed
(
unittest
.
TestCase
):
"""
Test cases for cpu generator seed.
"""
def
test_gen_dropout_dygraph
(
self
):
gen
=
paddle
.
manual_seed
(
12343
)
fluid
.
enable_dygraph
()
gen
.
manual_seed
(
111111111
)
st
=
paddle
.
get_cuda_rng_state
()
x
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
x_again
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
x_third
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
print
(
"x: {}"
.
format
(
x
.
numpy
()))
print
(
"x_again: {}"
.
format
(
x_again
.
numpy
()))
x
=
x
+
x_again
+
x_third
y
=
fluid
.
layers
.
dropout
(
x
,
0.5
)
paddle
.
set_cuda_rng_state
(
st
)
x1
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
x1_again
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
x1_third
=
fluid
.
layers
.
uniform_random
(
[
2
,
10
],
dtype
=
"float32"
,
min
=
0.0
,
max
=
1.0
)
x1
=
x1
+
x1_again
+
x1_third
y1
=
fluid
.
layers
.
dropout
(
x1
,
0.5
)
y_np
=
y
.
numpy
()
y1_np
=
y1
.
numpy
()
if
core
.
is_compiled_with_cuda
():
print
(
">>>>>>> dropout dygraph >>>>>>>"
)
self
.
assertTrue
(
np
.
allclose
(
y_np
,
y1_np
))
def
test_generator_gaussian_random_dygraph
(
self
):
"""Test Generator seed."""
fluid
.
enable_dygraph
()
paddle
.
manual_seed
(
12312321111
)
x
=
fluid
.
layers
.
gaussian_random
([
120
],
dtype
=
"float32"
)
st1
=
paddle
.
get_cuda_rng_state
()
x1
=
fluid
.
layers
.
gaussian_random
([
120
],
dtype
=
"float32"
)
paddle
.
set_cuda_rng_state
(
st1
)
x2
=
fluid
.
layers
.
gaussian_random
([
120
],
dtype
=
"float32"
)
paddle
.
manual_seed
(
12312321111
)
x3
=
fluid
.
layers
.
gaussian_random
([
120
],
dtype
=
"float32"
)
x_np
=
x
.
numpy
()
x1_np
=
x1
.
numpy
()
x2_np
=
x2
.
numpy
()
x3_np
=
x3
.
numpy
()
if
core
.
is_compiled_with_cuda
():
print
(
">>>>>>> gaussian random dygraph >>>>>>>"
)
self
.
assertTrue
(
np
.
allclose
(
x1_np
,
x2_np
))
self
.
assertTrue
(
np
.
allclose
(
x_np
,
x3_np
))
def
test_generator_randint_dygraph
(
self
):
"""Test Generator seed."""
fluid
.
enable_dygraph
()
gen
=
paddle
.
manual_seed
(
12312321111
)
x
=
paddle
.
randint
(
low
=
10
,
shape
=
[
10
],
dtype
=
"int32"
)
st1
=
gen
.
get_state
()
x1
=
paddle
.
randint
(
low
=
10
,
shape
=
[
10
],
dtype
=
"int32"
)
gen
.
set_state
(
st1
)
x2
=
paddle
.
randint
(
low
=
10
,
shape
=
[
10
],
dtype
=
"int32"
)
paddle
.
manual_seed
(
12312321111
)
x3
=
paddle
.
randint
(
low
=
10
,
shape
=
[
10
],
dtype
=
"int32"
)
x_np
=
x
.
numpy
()
x1_np
=
x1
.
numpy
()
x2_np
=
x2
.
numpy
()
x3_np
=
x3
.
numpy
()
if
core
.
is_compiled_with_cuda
():
print
(
">>>>>>> randint dygraph >>>>>>>"
)
self
.
assertTrue
(
np
.
allclose
(
x1_np
,
x2_np
))
self
.
assertTrue
(
np
.
allclose
(
x_np
,
x3_np
))
def
test_gen_TruncatedNormal_initializer
(
self
):
fluid
.
disable_dygraph
()
gen
=
paddle
.
manual_seed
(
123123143
)
cur_state
=
paddle
.
get_cuda_rng_state
()
startup_program
=
fluid
.
Program
()
train_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
train_program
,
startup_program
):
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
x
=
fluid
.
layers
.
uniform_random
(
shape
=
[
2
,
10
])
result_1
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
10
,
param_attr
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
2.0
))
result_2
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
10
,
param_attr
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
2.0
))
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
startup_program
)
out1
=
exe
.
run
(
train_program
,
feed
=
{},
fetch_list
=
[
result_1
,
result_2
])
paddle
.
manual_seed
(
123123143
)
with
fluid
.
program_guard
(
train_program
,
startup_program
):
exe
.
run
(
startup_program
)
out2
=
exe
.
run
(
train_program
,
feed
=
{},
fetch_list
=
[
result_1
,
result_2
])
out1_res1
=
np
.
array
(
out1
[
0
])
out1_res2
=
np
.
array
(
out1
[
1
])
out2_res1
=
np
.
array
(
out2
[
0
])
out2_res2
=
np
.
array
(
out2
[
1
])
if
core
.
is_compiled_with_cuda
():
print
(
">>>>>>> truncated normal static >>>>>>>"
)
self
.
assertTrue
(
np
.
allclose
(
out1_res1
,
out2_res1
))
self
.
assertTrue
(
np
.
allclose
(
out1_res2
,
out2_res2
))
self
.
assertTrue
(
not
np
.
allclose
(
out1_res2
,
out1_res1
))
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/framework/random.py
浏览文件 @
0a9f9f93
...
...
@@ -16,7 +16,7 @@
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
__all__
=
[
'manual_seed'
]
__all__
=
[
'manual_seed'
,
'get_cuda_rng_state'
,
'set_cuda_rng_state'
]
def
manual_seed
(
seed
):
...
...
@@ -42,10 +42,69 @@ def manual_seed(seed):
seed
=
int
(
seed
)
if
core
.
is_compiled_with_cuda
():
for
i
in
range
(
core
.
get_cuda_device_count
()):
core
.
default_cuda_generator
(
i
).
_is_init_py
=
True
core
.
default_cuda_generator
(
i
).
manual_seed
(
seed
)
core
.
default_cpu_generator
().
_is_init_py
=
True
return
core
.
default_cpu_generator
().
manual_seed
(
seed
)
def
get_cuda_rng_state
():
"""
Get random state of cuda generators.
Args:
None
Returns:
GeneratorState: object.
Examples:
.. code-block:: python
import paddle
sts = paddle.get_cuda_rng_state()
"""
state_list
=
[]
if
core
.
is_compiled_with_cuda
():
for
i
in
range
(
core
.
get_cuda_device_count
()):
state_list
.
append
(
core
.
default_cuda_generator
(
i
).
get_state
())
return
state_list
def
set_cuda_rng_state
(
state_list
):
"""
Sets generator state for all cuda generators
Args:
state_list(list): The cuda states to set back to cuda generators. state_list is obtained from get_cuda_rng_state().
Returns:
None
Examples:
.. code-block:: python
import paddle
sts = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(sts)
"""
if
core
.
is_compiled_with_cuda
():
if
not
len
(
state_list
)
==
core
.
get_cuda_device_count
():
raise
ValueError
(
"Length of cuda state list shoule be equal to the cuda device count"
)
for
i
in
range
(
core
.
get_cuda_device_count
()):
core
.
default_cuda_generator
(
i
).
set_state
(
state_list
[
i
])
def
_manual_program_seed
(
seed
):
"""
Sets global seed for generating random numbers.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录