PaddlePaddle / Paddle
Commit b9675acc (unverified)
Authored Feb 23, 2022 by zhouweiwei2014; committed via GitHub on Feb 23, 2022
change CUDA implementation of bernoulli OP (#39732)

* change CUDA implementation of bernoulli OP
* fix CI
Parent: 69a04209
Showing 8 changed files with 135 additions and 32 deletions.
paddle/fluid/operators/distribution_helper.h (+5 / -4)
paddle/phi/backends/gpu/gpu_launch_config.h (+1 / -0)
paddle/phi/kernels/gpu/bernoulli_kernel.cu (+68 / -14)
python/paddle/fluid/tests/unittests/test_bernoulli_op.py (+39 / -0)
python/paddle/fluid/tests/unittests/test_exponential_op.py (+6 / -5)
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py (+5 / -4)
python/paddle/fluid/tests/unittests/test_poisson_op.py (+6 / -1)
python/paddle/fluid/tests/unittests/test_uniform_random_op.py (+5 / -4)
paddle/fluid/operators/distribution_helper.h

```diff
@@ -180,8 +180,8 @@ struct normal_distribution<double> {
 /******** Launch GPU function of distribution and transformation *********/
 template <typename T, typename DistOp, typename TransformOp>
 __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
                                    DistOp dist, TransformOp trans,
-                                   T *out_data) {
+                                   T *out_data, size_t stride) {
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   static constexpr int kCount = DistOp::kReturnsCount;
 #if defined(__NVCC__)
@@ -201,7 +201,8 @@ __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
     kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
                                                            trans);
     kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
-                                             1, total_thread, 1);
+                                             1, stride, 1);
+    __syncthreads();
   }
 }
@@ -234,7 +235,7 @@ void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx,
   DistributionKernel<
       T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-      size, seed, offset, dist, trans, out_data);
+      size, seed, offset, dist, trans, out_data, total_thread);
 }
 #endif
```
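The point of the new `stride` parameter: the host computes the launch's total thread count once and passes it into the kernel, so the grid-stride write pattern is pinned to the launch configuration instead of being recomputed on the device. A minimal standalone sketch of that pattern (not Paddle code; `fill_strided` and the launch numbers are made up for illustration):

```cuda
// Sketch: a grid-stride kernel that, like the patched DistributionKernel,
// receives its stride from the host rather than deriving it per block.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill_strided(size_t size, float *out, size_t stride) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Each thread walks the buffer with the host-provided stride, so the
  // thread-to-element mapping is fixed by the launch configuration.
  for (size_t i = idx; i < size; i += stride) {
    out[i] = static_cast<float>(i);
  }
}

int main() {
  const size_t size = 1 << 20;
  float *out;
  cudaMalloc(&out, size * sizeof(float));

  const size_t block_size = 256;
  const size_t grid_size = 64;  // fixed grid; stride = total threads launched
  const size_t total_thread = grid_size * block_size;
  fill_strided<<<grid_size, block_size>>>(size, out, total_thread);
  cudaDeviceSynchronize();

  float first;
  cudaMemcpy(&first, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[0] = %f\n", first);
  cudaFree(out);
  return 0;
}
```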
paddle/phi/backends/gpu/gpu_launch_config.h

```diff
@@ -29,6 +29,7 @@
 #include <string>
 #include <vector>
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/enforce.h"
 #ifdef __HIPCC__
 // HIP results in error or nan if > 256
```
paddle/phi/kernels/gpu/bernoulli_kernel.cu

```diff
@@ -12,19 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <thrust/execution_policy.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#ifdef __NVCC__
+#include <curand_kernel.h>
+#endif
+#ifdef __HIPCC__
+#include <hiprand_kernel.h>
+#endif
 #include <algorithm>
 #include <vector>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/bernoulli_kernel.h"
 
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/distribution_helper.h"
 #include "paddle/fluid/platform/transform.h"
 
+DECLARE_bool(use_curand);
+
 namespace phi {
 
 template <typename T>
@@ -49,26 +60,69 @@ struct BernoulliCudaFunctor {
   }
 };
 
+// 'curand_uniform4/hiprand_uniform4' generate 4 random numbers each time
+template <typename T>
+__global__ void bernoulli_cuda_kernel(
+    size_t size, uint64_t seed, uint64_t offset, const T* x_data, T* out_data) {
+  size_t thread_idx =
+      static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+#if defined(__NVCC__)
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, thread_idx, offset, &state);
+#else
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, thread_idx, offset, &state);
+#endif
+
+  size_t total_thread = gridDim.x * blockDim.x;
+  for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
+    paddle::distribution::uniform_distribution<float> dist;
+    float4 rand = dist(&state);
+#pragma unroll
+    for (size_t j = 0; j < 4; j++) {
+      size_t idx = i + j;
+      if (idx < size) {
+        out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
+      }
+    }
+  }
+}
+
 template <typename T, typename Context>
 void BernoulliKernel(const Context& ctx,
                      const DenseTensor& x,
                      DenseTensor* out) {
-  auto numel = x.numel();
-  auto* x_data = x.data<T>();
+  const T* x_data = x.data<T>();
   T* out_data = ctx.template Alloc<T>(out);
+  auto numel = x.numel();
 
   auto gen_cuda = ctx.GetGenerator();
-  auto seed_offset = gen_cuda->IncrementOffset(1);
-  int64_t gen_offset = numel * seed_offset.second;
-  paddle::platform::Transform<phi::GPUContext> trans;
-  thrust::counting_iterator<int64_t> index_sequence_begin(0);
-  trans(ctx,
-        index_sequence_begin,
-        index_sequence_begin + numel,
-        x_data,
-        out_data,
-        BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
-                                static_cast<int64_t>(gen_offset)));
+
+  if (FLAGS_use_curand) {
+    auto seed_offset = gen_cuda->IncrementOffset(12);
+    uint64_t seed = seed_offset.first;
+    uint64_t offset = seed_offset.second;
+
+    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
+    size_t grid_size = gpu_config.GetGridSize();
+    size_t block_size = gpu_config.GetBlockSize();
+
+    bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
+        numel, seed, offset, x_data, out_data);
+  } else {
+    auto seed_offset = gen_cuda->IncrementOffset(1);
+    int64_t gen_offset = numel * seed_offset.second;
+    paddle::platform::Transform<phi::GPUContext> trans;
+    thrust::counting_iterator<int64_t> index_sequence_begin(0);
+    trans(ctx,
+          index_sequence_begin,
+          index_sequence_begin + numel,
+          x_data,
+          out_data,
+          BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
+                                  static_cast<int64_t>(gen_offset)));
+  }
 }
 
 }  // namespace phi
```
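For readers who want to try the sampling pattern in isolation: below is a self-contained sketch (not the Paddle kernel; `bernoulli_sketch` and its launch shape are illustrative) of the same Philox scheme. Each thread seeds its own `curandStatePhilox4_32_10_t` with `(seed, thread_idx, offset)`, draws four uniforms per `curand_uniform4` call, and thresholds them against the input probabilities. Because the mapping from thread to output elements is fixed by the seed and launch shape, results are reproducible for a given seed and configuration, which is what the new unit tests assert.

```cuda
// Sketch of Philox-based Bernoulli sampling, 4 values per round.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void bernoulli_sketch(size_t size, uint64_t seed, uint64_t offset,
                                 const float *probs, float *out) {
  size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // One counter-based RNG stream per thread: same (seed, subsequence, offset)
  // always reproduces the same draws.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, thread_idx, offset, &state);

  size_t total_thread = gridDim.x * blockDim.x;
  for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
    float4 rand = curand_uniform4(&state);  // four uniforms in (0, 1]
    const float r[4] = {rand.x, rand.y, rand.z, rand.w};
#pragma unroll
    for (int j = 0; j < 4; ++j) {
      size_t idx = i + j;
      if (idx < size) out[idx] = (r[j] <= probs[idx]) ? 1.f : 0.f;
    }
  }
}

int main() {
  const size_t n = 1024;
  float *probs, *out;
  cudaMallocManaged(&probs, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) probs[i] = 0.5f;

  bernoulli_sketch<<<4, 64>>>(n, /*seed=*/100, /*offset=*/0, probs, out);
  cudaDeviceSynchronize();

  float sum = 0.f;
  for (size_t i = 0; i < n; ++i) sum += out[i];
  printf("empirical mean: %f (expect ~0.5)\n", sum / n);
  cudaFree(probs);
  cudaFree(out);
  return 0;
}
```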
python/paddle/fluid/tests/unittests/test_bernoulli_op.py

```diff
@@ -18,6 +18,7 @@ import unittest
 import paddle
 from op_test import OpTest
 import numpy as np
+import os
 
 
 def output_hist(out):
@@ -68,5 +69,43 @@ class TestBernoulliApi(unittest.TestCase):
             hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
 
 
+class TestRandomValue(unittest.TestCase):
+    def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
+        print("Test Fixed Random number on GPU------>")
+        paddle.disable_static()
+        paddle.set_device('gpu')
+        paddle.seed(100)
+        np.random.seed(100)
+
+        x_np = np.random.rand(32, 1024, 1024)
+
+        x = paddle.to_tensor(x_np, dtype='float64')
+        y = paddle.bernoulli(x).numpy()
+        index0, index1, index2 = np.nonzero(y)
+        self.assertEqual(np.sum(index0), 260028995)
+        self.assertEqual(np.sum(index1), 8582429431)
+        self.assertEqual(np.sum(index2), 8581445798)
+        expect = [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.]
+        self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
+
+        x = paddle.to_tensor(x_np, dtype='float32')
+        y = paddle.bernoulli(x).numpy()
+        index0, index1, index2 = np.nonzero(y)
+        self.assertEqual(np.sum(index0), 260092343)
+        self.assertEqual(np.sum(index1), 8583509076)
+        self.assertEqual(np.sum(index2), 8582778540)
+        expect = [0., 0., 1., 1., 1., 1., 0., 1., 1., 1.]
+        self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
```
python/paddle/fluid/tests/unittests/test_exponential_op.py

```diff
@@ -16,6 +16,7 @@ import unittest
 import paddle
 import numpy as np
 from op_test import OpTest
+import os
 
 paddle.enable_static()
 paddle.seed(100)
@@ -90,18 +91,18 @@ class TestExponentialAPI(unittest.TestCase):
         self.assertTrue(np.min(x.numpy()) >= 0)
         paddle.enable_static()
 
-    # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
     def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
         if not paddle.is_compiled_with_cuda():
             return
 
-        # Different GPU generatte different random value. Only test V100 here.
+        # Note(zhouwei): The Number of threads is determined by
+        # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
+        # GPU have different number of threads, which result in different
+        # random value. Only test on V100 GPU here.
         if not "V100" in paddle.device.cuda.get_device_name():
             return
 
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
```
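The V100-only guard exists because the note above ties the launch's thread count to `multiProcessorCount * maxThreadsPerMultiProcessor`, which varies by GPU model; a different thread count changes which Philox stream produces each element, so the fixed values below would not match on other devices. A standalone query for the quantity the note names, assuming CUDA device 0:

```cuda
// Sketch: print the per-device thread capacity the test comment refers to.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, /*device=*/0);
  int max_threads =
      prop.multiProcessorCount * prop.maxThreadsPerMultiProcessor;
  printf("%s: %d SMs x %d threads/SM = %d resident threads\n",
         prop.name, prop.multiProcessorCount,
         prop.maxThreadsPerMultiProcessor, max_threads);
  return 0;
}
```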
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py

```diff
@@ -14,6 +14,7 @@
 from __future__ import print_function
 
+import os
 import unittest
 import numpy as np
 import paddle
@@ -293,13 +294,13 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
 
-        # Different GPU generatte different random value. Only test V100 here.
+        # Note(zhouwei): The Number of threads is determined by
+        # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
+        # GPU have different number of threads, which result in different
+        # random value. Only test on V100 GPU here.
         if not "V100" in paddle.device.cuda.get_device_name():
             return
 
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
         def _check_random_value(dtype, expect, expect_mean, expect_std):
             x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
             actual = x.numpy()
```
python/paddle/fluid/tests/unittests/test_poisson_op.py

```diff
@@ -17,6 +17,7 @@ import paddle
 import numpy as np
 from op_test import OpTest
 import math
+import os
 
 paddle.enable_static()
 paddle.seed(100)
@@ -101,11 +102,15 @@ class TestPoissonAPI(unittest.TestCase):
         self.assertTrue(np.min(y.numpy()) >= 0)
         paddle.enable_static()
 
-    # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
     def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
         if not paddle.is_compiled_with_cuda():
             return
 
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
         paddle.seed(2021)
```
python/paddle/fluid/tests/unittests/test_uniform_random_op.py

```diff
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import sys
+import os
 import subprocess
 import unittest
 import numpy as np
@@ -568,13 +569,13 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
 
-        # Different GPU generate different random value. Only test V100 here.
+        # Note(zhouwei): The Number of threads is determined by
+        # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
+        # GPU have different number of threads, which result in different
+        # random value. Only test on V100 GPU here.
         if not "V100" in paddle.device.cuda.get_device_name():
             return
 
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
         def _check_random_value(dtype, expect, expect_mean, expect_std):
             x = paddle.rand([32, 3, 1024, 1024], dtype=dtype)
             actual = x.numpy()
```