Unverified commit 9714878c, authored Apr 07, 2022 by zhouweiwei2014, committed by GitHub on Apr 07, 2022.
remove FLAGS_use_curand and change all random op CUDA implementation (#41308)
Parent: 0d642d3a
Showing 26 changed files with 267 additions and 625 deletions (+267, -625).
paddle/fluid/operators/dropout_impl.cu.h                          +40  -111
paddle/fluid/operators/gaussian_random_op.cu                       +0    -7
paddle/fluid/operators/uniform_random_op.h                         +4   -50
paddle/fluid/platform/flags.cc                                     +0    -2
paddle/phi/kernels/cpu/transpose_kernel.cc                         +1    -0
paddle/phi/kernels/gpu/bernoulli_kernel.cu                         +8   -51
paddle/phi/kernels/gpu/gaussian_random_kernel.cu                   +5   -20
paddle/phi/kernels/gpu/multinomial_kernel.cu                      +57  -156
paddle/phi/kernels/gpu/randint_kernel.cu                           +3   -33
paddle/phi/kernels/gpu/randperm_kernel.cu                         +59   -85
paddle/phi/kernels/gpu/uniform_random_kernel.cu                    +5   -56
paddle/scripts/paddle_build.bat                                    +0    -1
paddle/scripts/paddle_build.sh                                     +0    -2
python/paddle/fluid/initializer.py                                 +8    -8
python/paddle/fluid/tests/unittests/test_bernoulli_op.py           +0    -3
python/paddle/fluid/tests/unittests/test_dropout_op.py             +0    -3
python/paddle/fluid/tests/unittests/test_exponential_op.py         +0    -3
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py     +0    -3
python/paddle/fluid/tests/unittests/test_linear.py                +16    -0
python/paddle/fluid/tests/unittests/test_multinomial_op.py         +0    -3
python/paddle/fluid/tests/unittests/test_poisson_op.py             +0    -3
python/paddle/fluid/tests/unittests/test_randint_op.py             +0    -3
python/paddle/fluid/tests/unittests/test_randperm_op.py            +0    -3
python/paddle/fluid/tests/unittests/test_uniform_random_op.py     +27   -18
python/paddle/nn/utils/__init__.py                                 +1    -1
python/paddle/nn/utils/transform_parameters.py                    +33    -0
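All of the C++ changes below follow one pattern. As an illustrative sketch (plain C++, not Paddle's real code): before this commit, every random op checked a gflags switch at run time; the commit deletes the flag along with the legacy branch, leaving the counter-based cuRAND path unconditional.

    // Illustrative only; names mirror the pattern, not Paddle's sources.
    static bool FLAGS_use_curand = false;  // the flag this commit removes

    void GenerateUniform(float* out, int n) {
      if (FLAGS_use_curand) {
        // counter-based cuRAND/Philox implementation -- kept, now unconditional
      } else {
        // legacy thrust / std::mt19937_64 implementation -- deleted
      }
    }

Accordingly, each file below loses its DECLARE_bool(use_curand); declaration, each if/else collapses to the cuRAND body, and the CI scripts stop exporting FLAGS_use_curand=True.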
paddle/fluid/operators/dropout_impl.cu.h

@@ -38,43 +38,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
-DECLARE_bool(use_curand);
 namespace paddle {
 namespace operators {
-template <typename T1, typename T2 = T1, typename OutT = T1>
-struct DstMaskGenerator {
-  const float dropout_prob_;
-  const bool is_upscale_in_train_;
-  using MT = typename details::MPTypeTrait<T1>::Type;
-  MT factor;
-  HOSTDEVICE inline DstMaskGenerator(const float dropout_prob,
-                                     const bool is_upscale_in_train)
-      : dropout_prob_(dropout_prob),
-        is_upscale_in_train_(is_upscale_in_train) {
-    factor = static_cast<MT>(1.0f / (1.0f - dropout_prob_));
-  }
-  HOSTDEVICE inline void operator()(OutT* dst, const T1* src_val,
-                                    const T2* rand, int num) const {
-    static constexpr int kCount =
-        phi::funcs::uniform_distribution<T2>::kReturnsCount;
-    // 0 ~ kCount - 1 is dist, kCount ~ 2 * kCount - 1 is mask
-#pragma unroll
-    for (int i = 0; i < kCount; i++) {
-      if (rand[i] < dropout_prob_) {
-        dst[i] = static_cast<T1>(0);
-        dst[i + kCount] = dst[i];
-      } else {
-        dst[i] = is_upscale_in_train_
-                     ? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
-                     : static_cast<T1>(src_val[i]);
-        dst[i + kCount] = static_cast<T1>(1);
-      }
-    }
-  }
-};
 template <typename T1, typename T2 = T1, typename OutT = T1>
 struct DstMaskFunctor {
   const float retain_prob_;

@@ -113,7 +79,7 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                           const T* src, MaskType* mask, T* dst,
                                           bool is_upscale_in_train,
-                                          uint64_t increment, size_t main_offset,
-                                          bool use_curand) {
+                                          uint64_t increment,
+                                          size_t main_offset) {
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   static constexpr int kCount =
       phi::funcs::uniform_distribution<float>::kReturnsCount;
@@ -135,76 +101,41 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
   int deal_size = BLOCK_NUM_X * kCount;
   size_t fix = idx * kCount;
-  if (use_curand) {
   auto dst_functor =
       DstMaskFunctor<T, float>(1.0f - dropout_prob, is_upscale_in_train);
   for (; fix < main_offset; fix += stride) {
     kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
     kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                           &state);
     // dst
     kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
         &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
     kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0], deal_size);
     // mask
     kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix, &mask_result[0],
                                                   deal_size);
     if (fix > idx * kCount + 1) {
       __syncthreads();
     }
   }
   int remainder = n - fix;
   if (remainder > 0) {
     kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
     kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
                                                           &state);
     // dst
     kps::OperatorTernary<T, float, T, DstMaskFunctor<T, float>>(
         &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
     kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
     // mask
     kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
         &mask_result[0], &dst_mask[kCount], Cast());
     kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
                                                  remainder);
+    __syncthreads();
   }
-  } else {
-    auto dst_functor =
-        DstMaskGenerator<T, float>(dropout_prob, is_upscale_in_train);
-    for (; fix < main_offset; fix += stride) {
-      kps::ReadData<T, kCount, 1, 1, false>(&dst_mask[0], src + fix, deal_size);
-      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                            &state);
-      // dst
-      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-      kps::WriteData<T, kCount, 1, 1, false>(dst + fix, &dst_mask[0],
-                                             deal_size);
-      // mask
-      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-          &mask_result[0], &dst_mask[kCount], Cast());
-      kps::WriteData<MaskType, kCount, 1, 1, false>(mask + fix,
-                                                    &mask_result[0], deal_size);
-    }
-    int remainder = n - fix;
-    if (remainder > 0) {
-      kps::ReadData<T, kCount, 1, 1, true>(&dst_mask[0], src + fix, remainder);
-      kps::ElementwiseRandom<SType, float, kCount, 1, Rand>(&rands[0], Rand(),
-                                                            &state);
-      // dst
-      kps::OperatorTernary<T, float, T, DstMaskGenerator<T, float>>(
-          &dst_mask[0], &dst_mask[0], &rands[0], dst_functor, kCount);
-      kps::WriteData<T, kCount, 1, 1, true>(dst + fix, &dst_mask[0], remainder);
-      // mask
-      kps::ElementwiseUnary<T, MaskType, kCount, 1, 1, Cast>(
-          &mask_result[0], &dst_mask[kCount], Cast());
-      kps::WriteData<MaskType, kCount, 1, 1, true>(mask + fix, &mask_result[0],
-                                                   remainder);
-    }
-  }
 }
@@ -251,13 +182,11 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     size_t grid_size = gpu_config.GetGridSize();
     size_t block_size = gpu_config.GetBlockSize();
-    if (FLAGS_use_curand) {
-      int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
-      const auto& prop = platform::GetDeviceProperties(device_id);
-      size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
-                             prop.multiProcessorCount / block_size;
-      grid_size = std::min(grid_size, max_grid_size);
-    }
+    int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
+    const auto& prop = platform::GetDeviceProperties(device_id);
+    size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
+                           prop.multiProcessorCount / block_size;
+    grid_size = std::min(grid_size, max_grid_size);
     auto offset =
         ((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
@@ -268,7 +197,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
     VectorizedRandomGenerator<T, uint8_t><<<grid_size, block_size, 0, stream>>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
-        upscale_in_train, increment, main_offset, FLAGS_use_curand);
+        upscale_in_train, increment, main_offset);
   } else {
     if (upscale_in_train) {
       // todo: can y share with data with x directly?
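The kept path draws its randomness from cuRAND's counter-based Philox generator: the kernels receive a (seed, offset) pair from the framework generator's IncrementOffset, and every thread derives an independent stream from the same seed. A minimal self-contained CUDA sketch of that pattern (illustrative; Paddle's actual plumbing lives in phi's distribution_helper.h and the kps wrappers):

    #include <curand_kernel.h>

    __global__ void PhiloxUniform4(float4* out, size_t n,
                                   uint64_t seed, uint64_t offset) {
      size_t idx = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
      if (idx >= n) return;
      curandStatePhilox4_32_10_t state;
      // Same seed everywhere; the thread index selects the subsequence and
      // `offset` advances the counter so later launches reuse no randomness.
      curand_init(seed, idx, offset, &state);
      out[idx] = curand_uniform4(&state);  // four uniforms in (0, 1] per call
    }

Because the state is a pure counter, the output is reproducible for a fixed seed, which is what lets the unit tests further below pin exact expected values.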
paddle/fluid/operators/gaussian_random_op.cu

@@ -11,21 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
-#include <thrust/random.h>
-#include <thrust/transform.h>
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fill_constant_op.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/index_impl.cu.h"
-DECLARE_bool(use_curand);
 namespace paddle {
 namespace operators {
paddle/fluid/operators/uniform_random_op.h

@@ -19,11 +19,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-DECLARE_bool(use_curand);
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"

@@ -146,39 +142,6 @@ struct UniformGenerator {
   }
 };
-template <typename T>
-struct UniformGeneratorOffset {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  int offset_;
-  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
-                                             int diag_num, int diag_step,
-                                             T diag_val, int offset)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val),
-        offset_(offset) {}
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n + offset_);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
 template <typename T>
 void UniformRandom(const framework::ExecutionContext& context,
                    framework::Tensor* tensor) {

@@ -205,19 +168,10 @@ void UniformRandom(const framework::ExecutionContext& context,
   int device_id = context.GetPlace().GetDeviceId();
   auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
   if (gen_cuda->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename details::MPTypeTrait<T>::Type;
-      phi::funcs::uniform_distribution<MT> dist;
-      phi::funcs::uniform_real_transform<MT> trans(min, max);
-      phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
-    } else {
-      auto seed_offset = gen_cuda->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func = UniformGeneratorOffset<T>(min, max, seed_offset.first,
-                                            diag_num, diag_step, diag_val,
-                                            gen_offset);
-      phi::IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
-    }
+    using MT = typename details::MPTypeTrait<T>::Type;
+    phi::funcs::uniform_distribution<MT> dist;
+    phi::funcs::uniform_real_transform<MT> trans(min, max);
+    phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
   } else {
     auto func = UniformGenerator<T>(min, max, seed, diag_num, diag_step,
                                     diag_val);
paddle/fluid/platform/flags.cc

@@ -545,8 +545,6 @@ PADDLE_DEFINE_EXPORTED_double(
 */
 PADDLE_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run");
-PADDLE_DEFINE_EXPORTED_bool(use_curand, false, "Random OP use CURAND");
 /**
  * Debug related FLAG
  * Name: FLAGS_call_stack_level
paddle/phi/kernels/cpu/transpose_kernel.cc

@@ -75,6 +75,7 @@ PD_REGISTER_KERNEL(transpose,
                    double,
                    int32_t,
                    int64_t,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
paddle/phi/kernels/gpu/bernoulli_kernel.cu

@@ -14,8 +14,6 @@
 #include "paddle/phi/kernels/bernoulli_kernel.h"
-#include <thrust/random.h>
-#include <thrust/transform.h>
 #ifdef __NVCC__
 #include <curand_kernel.h>
 #endif

@@ -32,35 +30,8 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 // See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/transform.h"
-DECLARE_bool(use_curand);
 namespace phi {
-template <typename T>
-struct BernoulliCudaFunctor {
-  unsigned int seed_;
-  unsigned int offset_;
-  __host__ __device__ BernoulliCudaFunctor(unsigned int seed,
-                                           unsigned int offset)
-      : seed_(seed), offset_(offset) {}
-  __host__ __device__ T operator()(const unsigned int n, const T p) const {
-    // NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
-    // lines of error messages if, and it should be refined.
-    PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
-                   "The probability should be >=0 and <= 1, but got %f",
-                   p);
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
-    rng.discard(n + offset_);
-    return static_cast<T>(dist(rng) < p);
-  }
-};
 // 'curand_uniform4/hiprand_uniform4' generate 4 random number each time
 template <typename T>
 __global__ void bernoulli_cuda_kernel(

@@ -100,30 +71,16 @@ void BernoulliKernel(const Context& ctx,
   auto gen_cuda = ctx.GetGenerator();
-  if (FLAGS_use_curand) {
   auto seed_offset = gen_cuda->IncrementOffset(12);
   uint64_t seed = seed_offset.first;
   uint64_t offset = seed_offset.second;
   auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
   size_t grid_size = gpu_config.GetGridSize();
   size_t block_size = gpu_config.GetBlockSize();
   bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
       numel, seed, offset, x_data, out_data);
-  } else {
-    auto seed_offset = gen_cuda->IncrementOffset(1);
-    int64_t gen_offset = numel * seed_offset.second;
-    paddle::platform::Transform<phi::GPUContext> trans;
-    thrust::counting_iterator<int64_t> index_sequence_begin(0);
-    trans(ctx,
-          index_sequence_begin,
-          index_sequence_begin + numel,
-          x_data,
-          out_data,
-          BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
-                                  static_cast<int64_t>(gen_offset)));
-  }
 }
 }  // namespace phi
paddle/phi/kernels/gpu/gaussian_random_kernel.cu

@@ -14,10 +14,7 @@
 #include "paddle/phi/kernels/gaussian_random_kernel.h"
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
-#include <thrust/random.h>
-#include <thrust/transform.h>
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"

@@ -27,8 +24,6 @@
 #include "paddle/fluid/framework/generator.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>

@@ -83,21 +78,11 @@ void GaussianRandomKernel(const Context& dev_ctx,
   auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
   if (gen_cuda->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-      funcs::normal_distribution<MT> dist;
-      funcs::normal_transform<MT> trans(static_cast<MT>(mean),
-                                        static_cast<MT>(std));
-      funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
-    } else {
-      auto seed_offset = gen_cuda->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func = GaussianGenerator<T>(static_cast<T>(mean),
-                                       static_cast<T>(std),
-                                       seed_offset.first,
-                                       gen_offset);
-      IndexKernel<T, GaussianGenerator<T>>(dev_ctx, tensor, func);
-    }
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    funcs::normal_distribution<MT> dist;
+    funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                      static_cast<MT>(std));
+    funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
   } else {
     auto func =
         GaussianGenerator<T>(static_cast<T>(mean), static_cast<T>(std), seed);
paddle/phi/kernels/gpu/multinomial_kernel.cu

@@ -18,11 +18,6 @@ limitations under the License. */
 #include "paddle/phi/kernels/multinomial_kernel.h"
-#include <thrust/execution_policy.h>
-#include <thrust/random.h>
-#include <thrust/scan.h>
-#include <thrust/transform.h>
 #ifdef __NVCC__
 #include "cub/cub.cuh"
 #endif

@@ -44,12 +39,6 @@ namespace cub = hipcub;
 #include "paddle/phi/kernels/funcs/multinomial_functor.h"
 #include "paddle/phi/kernels/top_k_kernel.h"
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/fluid/platform/transform.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>

@@ -74,32 +63,6 @@ __global__ void NormalizeProbability(T* norm_probs,
   }
 }
-template <typename T>
-__global__ void GetCumulativeProbs(T* norm_probs_data,
-                                   int64_t num_distributions,
-                                   int64_t num_categories,
-                                   T* cumulative_probs_data) {
-  int id = blockIdx.x;
-  thrust::inclusive_scan(thrust::device,
-                         norm_probs_data + id * num_categories,
-                         norm_probs_data + (id + 1) * num_categories,
-                         cumulative_probs_data + id * num_categories);
-}
-template <typename T>
-struct RandomGeneratorCudaFunctor {
-  unsigned int seed_;
-  __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {}
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
-    rng.discard(n);
-    return dist(rng);
-  }
-};
 template <typename T>
 __device__ int binarySearchFunctor(T* cumulative_probs_data,
                                    T* norm_probs_data,

@@ -130,7 +93,6 @@ __device__ int binarySearchFunctor(T* cumulative_probs_data,
 template <typename T>
 __global__ void sampleMultinomialWithReplacement(
-    T* rng_data,
     const int64_t num_samples,
     int64_t* out_data,
     const int64_t num_distributions,

@@ -138,10 +100,9 @@ __global__ void sampleMultinomialWithReplacement(
     T* cumulative_probs_data,
     T* norm_probs_data,
     uint64_t seed,
-    uint64_t offset,
-    bool use_curand) {
+    uint64_t offset) {
   // use binary search to get the selected category sample id.
-  // let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id].
+  // let cumulative_probs_data[id-1] < rng_number < cumulative_probs_data[id].
   size_t idx = gridDim.x * blockDim.x * blockIdx.y +
                blockDim.x * blockIdx.x + threadIdx.x;

@@ -151,10 +112,7 @@ __global__ void sampleMultinomialWithReplacement(
   int sample = blockIdx.x * blockDim.x + threadIdx.x;
   for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
     if (sample < num_samples) {
-      T rng_number = rng_data[sample + dist * num_samples];
-      if (use_curand) {
-        rng_number = static_cast<T>(curand_uniform4(&state).x);
-      }
+      T rng_number = static_cast<T>(curand_uniform4(&state).x);
       // Find the bucket that a uniform random number lies in
       int selected_category =
           binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,

@@ -182,10 +140,7 @@ void MultinomialKernel(const Context& dev_ctx,
   const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
   // If replacement is False, it's not a replaceable sample. Every category
-  // can
-  // be used only once. So after every sample, probability of the distribution
-  // will change. The implementation can't be parallelizable. Thus, call CPU
-  // implementation ``funcs::MultinomialFunctor`` to sample the distribution.
+  // can be used only once.
   if (!replacement) {
     int64_t in_data_numel = x.numel();
     int64_t out_data_numel = out->numel();
@@ -202,76 +157,50 @@ void MultinomialKernel(const Context& dev_ctx,
               in_data_numel * sizeof(T),
               cudaMemcpyDeviceToHost);
 #endif
-    if (FLAGS_use_curand) {
     for (size_t i = 0; i < num_distributions; ++i) {
       int zero_num = 0;
       for (size_t j = 0; j < num_categories; ++j) {
         T weight = cpu_in_data[i * num_distributions + j];
         PADDLE_ENFORCE_GE(
             weight,
             0,
             errors::InvalidArgument(
                 "Each element of multinomial'input must >= 0, but got %f.",
                 weight));
         if (weight == static_cast<T>(0)) {
           zero_num++;
         }
       }
       int valid_samples = num_categories - zero_num;
       PADDLE_ENFORCE_LE(
           num_samples,
           valid_samples,
           errors::InvalidArgument("When replacement=False, 'num_samples' "
                                   "must less than or eaqual to the number of "
                                   "positive item of input"));
     }
     // Refer to [gumbel softmax algorithm]
     DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
     T* rand_data = rand.data<T>();
     funcs::uniform_distribution<T> dist;
     funcs::exponential_transform<T> trans(1.0);
     funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
     funcs::ForRange<Context> for_range(dev_ctx, x.numel());
     for_range([rand_data, in_data] __device__(size_t idx) {
       rand_data[idx] = in_data[idx] / rand_data[idx];
     });
     if (num_samples == 1) {
       ArgMaxKernel<T, Context>(
           dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
     } else {
       std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
       DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
       TopkKernel<T, Context>(
           dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out);
     }
     return;
-    }
-    funcs::MultinomialFunctor<T>(dev_ctx,
-                                 cpu_out_data,
-                                 cpu_in_data,
-                                 num_samples,
-                                 replacement,
-                                 num_categories,
-                                 num_distributions);
-#ifdef PADDLE_WITH_HIP
-    hipMemcpy(out_data,
-              cpu_out_data,
-              out_data_numel * sizeof(int64_t),
-              hipMemcpyHostToDevice);
-#else
-    cudaMemcpy(out_data,
-               cpu_out_data,
-               out_data_numel * sizeof(int64_t),
-               cudaMemcpyHostToDevice);
-#endif
-    delete[] cpu_in_data;
-    delete[] cpu_out_data;
   }
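The "[gumbel softmax algorithm]" comment relies on a standard identity, restated here for context (it is not part of the patch). After the transform above, rand_data[i] holds $w_i / E_i$, where the $E_i \sim \mathrm{Exp}(1)$ are i.i.d. and the $w_i \ge 0$ are the input weights. Since $E_i / w_i \sim \mathrm{Exp}(w_i)$,

$$\Pr\Big[\operatorname*{arg\,max}_i\, w_i/E_i = k\Big] \;=\; \Pr\Big[\operatorname*{arg\,min}_i\, E_i/w_i = k\Big] \;=\; \frac{w_k}{\sum_j w_j},$$

so a single arg-max draws one category with the correct probabilities, and taking the top-$k$ values yields a weighted sample without replacement; hence the dispatch to ArgMaxKernel when num_samples == 1 and to TopkKernel otherwise.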
@@ -322,44 +251,18 @@ void MultinomialKernel(const Context& dev_ctx,
   auto* cumulative_probs_data =
       dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
-  if (FLAGS_use_curand) {
   // 'phi::funcs::InclusiveScan' has higher accuracy than
   // 'thrust::inclusive_scan'
   funcs::InclusiveScan<T, std::plus<T>>(
       /*in*/ norm_probs_data,
       /*out*/ cumulative_probs_data,
       /*outer_dim*/ static_cast<size_t>(num_distributions),
       /*mid_dim*/ static_cast<size_t>(num_categories),
       /*inner_dim*/ static_cast<size_t>(1),
       /*init*/ static_cast<T>(0),
       std::plus<T>(),
       /*reverse=*/false,
       dev_ctx);
-  } else {
-    dim3 block_cumsum(1);
-    dim3 grid_cumsum(num_distributions);
-    GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
-        norm_probs_data,
-        num_distributions,
-        num_categories,
-        cumulative_probs_data);
-  }
-  // Generate random number for each sample.
-  std::random_device rd;
-  auto seed = rd();
-  DenseTensor rng_data_tensor;
-  rng_data_tensor.Resize({num_distributions, num_samples});
-  auto* rng_data = dev_ctx.template Alloc<T>(&rng_data_tensor);
-  thrust::counting_iterator<int64_t> index_sequence_begin(0);
-  paddle::platform::Transform<GPUContext> trans;
-  trans(dev_ctx,
-        index_sequence_begin,
-        index_sequence_begin + num_distributions * num_samples,
-        rng_data,
-        RandomGeneratorCudaFunctor<T>(seed));
   // Sample the multinomial distributions.
   dim3 block(128);

@@ -376,7 +279,6 @@ void MultinomialKernel(const Context& dev_ctx,
   auto seed_offset = gen_cuda->IncrementOffset(increment);
   sampleMultinomialWithReplacement<T><<<grid, block, 0, dev_ctx.stream()>>>(
-      rng_data,
       num_samples,
       out_data,
       num_distributions,

@@ -384,8 +286,7 @@ void MultinomialKernel(const Context& dev_ctx,
       cumulative_probs_data,
       norm_probs_data,
       seed_offset.first,
-      seed_offset.second,
-      FLAGS_use_curand);
+      seed_offset.second);
 }
 }  // namespace phi
paddle/phi/kernels/gpu/randint_kernel.cu

@@ -23,8 +23,6 @@
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T, typename Context>

@@ -37,37 +35,9 @@ void RandintRawKernel(const Context& dev_ctx,
                       DenseTensor* out) {
   out->Resize(phi::make_ddim(shape.GetData()));
   T* data = dev_ctx.template Alloc<T>(out);
-  if (FLAGS_use_curand) {
-    funcs::uniform_distribution<uint32_t> dist;
-    funcs::uniform_int_transform<T, uint32_t> trans(low, high);
-    funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
-  } else {
-    DenseTensor tmp;
-    tmp.Resize(phi::make_ddim(shape.GetData()));
-    T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
-    std::shared_ptr<std::mt19937_64> engine;
-    if (seed) {
-      engine = std::make_shared<std::mt19937_64>();
-      engine->seed(seed);
-    } else {
-      engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
-    }
-    std::uniform_int_distribution<T> dist(low, high - 1);
-    auto numel = out->numel();
-    for (int64_t i = 0; i < numel; ++i) {
-      tmp_data[i] = dist(*engine);
-    }
-    paddle::memory::Copy<phi::GPUPlace, phi::Place>(
-        out->place(), data, tmp.place(), tmp_data,
-        numel * paddle::experimental::SizeOf(out->dtype()), 0);
-  }
+  funcs::uniform_distribution<uint32_t> dist;
+  funcs::uniform_int_transform<T, uint32_t> trans(low, high);
+  funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
 }
 template <typename T, typename Context>
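funcs::uniform_int_transform turns raw 32-bit uniform samples into integers in [low, high). A hypothetical sketch of such a transform, using the simplest modulo mapping (an assumption for illustration; Paddle's real functor may use a different, bias-free scheme):

    #include <cstdint>

    template <typename T>
    struct UniformIntTransformSketch {  // hypothetical, not Paddle's type
      T low_;
      uint32_t range_;  // high - low
      UniformIntTransformSketch(T low, T high)
          : low_(low), range_(static_cast<uint32_t>(high - low)) {}
      T operator()(uint32_t bits) const {
        // Modulo mapping is simple but slightly biased unless range_
        // divides 2^32 exactly.
        return low_ + static_cast<T>(bits % range_);
      }
    };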
paddle/phi/kernels/gpu/randperm_kernel.cu

@@ -84,91 +84,65 @@ __global__ void SwapRepeatKernel(
 template <typename T, typename Context>
 void RandpermRawKernel(
     const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
-  if (FLAGS_use_curand) {
   DenseTensor key;
   RandintKernel<int, Context>(dev_ctx,
                               std::numeric_limits<int>::min(),
                               std::numeric_limits<int>::max(),
                               IntArray({n}),
                               phi::DataType::INT32,
                               &key);
   DenseTensor key_out = Empty<int, Context>(dev_ctx, IntArray({n}));
   DenseTensor range = Empty<T, Context>(dev_ctx, IntArray({n}));
   T* range_data = range.data<T>();
   funcs::ForRange<Context> for_range(dev_ctx, n);
   for_range([range_data] __device__(size_t idx) {
     range_data[idx] = static_cast<T>(idx);
   });
   out->Resize(phi::make_ddim({n}));
   T* out_data = dev_ctx.template Alloc<T>(out);
   // Refer to [Algorithm of randperm] https://osf.io/af2hy/ to
   // improve performance of radix sort.
   double n_d = static_cast<double>(n);
   int begin_bit = 0;
   int end_bit =
       std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9))));
   size_t temp_storage_bytes = 0;
   cub::DeviceRadixSort::SortPairs<int, T>(nullptr,
                                           temp_storage_bytes,
                                           key.data<int>(),
                                           key_out.data<int>(),
                                           range.data<T>(),
                                           out_data,
                                           n,
                                           begin_bit,
                                           end_bit < 32 ? end_bit : 32,
                                           dev_ctx.stream());
   auto d_temp_storage = paddle::memory::Alloc(dev_ctx, temp_storage_bytes);
   cub::DeviceRadixSort::SortPairs<int, T>(d_temp_storage->ptr(),
                                           temp_storage_bytes,
                                           key.data<int>(),
                                           key_out.data<int>(),
                                           range.data<T>(),
                                           out_data,
                                           n,
                                           begin_bit,
                                           end_bit < 32 ? end_bit : 32,
                                           dev_ctx.stream());
   auto gen_cuda = dev_ctx.GetGenerator();
   auto seed_offset = gen_cuda->IncrementOffset(n);
-  uint64_t seed = seed_offset.first;
-  uint64_t offset = seed_offset.second;
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
   SwapRepeatKernel<T><<<config.block_per_grid.x,
                         config.thread_per_block.x,
                         0,
                         dev_ctx.stream()>>>(
-      key_out.data<int>(), out_data, n, seed, offset);
-  } else {
-    DenseTensor tmp;
-    tmp.Resize(phi::make_ddim({n}));
-    T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
-    std::shared_ptr<std::mt19937_64> engine;
-    if (seed) {
-      engine = std::make_shared<std::mt19937_64>();
-      engine->seed(seed);
-    } else {
-      engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
-    }
-    for (int i = 0; i < n; ++i) {
-      tmp_data[i] = static_cast<T>(i);
-    }
-    std::shuffle(tmp_data, tmp_data + n, *engine);
-    T* out_data = dev_ctx.template Alloc<T>(out);
-    auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
-    paddle::memory::Copy<phi::GPUPlace, phi::Place>(
-        out->place(), out_data, tmp.place(), tmp_data, size, 0);
-  }
+      key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
 }
 template <typename T, typename Context>
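The end_bit expression matches a birthday-bound estimate (a reconstruction for context, not text from the patch). Sorting n random integer keys produces a uniform permutation only where all keys are distinct, and b key bits keep them distinct with probability at least 0.9 when

$$\Pr[\text{all } n \text{ keys distinct}] \approx \exp\!\Big(-\frac{n(n-1)}{2 \cdot 2^{b}}\Big) \ge 0.9 \quad\Longrightarrow\quad 2^{b} \gtrsim n + \frac{n^{2}}{2\ln(1/0.9)} \approx n + 4.75\,n^{2},$$

which appears to be what std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9))) encodes (std::log is the natural log, and log 0.9 < 0 makes the quadratic term positive). Restricting the radix sort to end_bit bits is what speeds it up, and SwapRepeatKernel then repairs the rare keys that still collide.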
paddle/phi/kernels/gpu/uniform_random_kernel.cu

@@ -14,14 +14,13 @@
 #include "paddle/phi/kernels/uniform_random_kernel.h"
 #include <thrust/random.h>
 #include "gflags/gflags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/index_impl.cu.h"
-DECLARE_bool(use_curand);
 namespace phi {
 template <typename T>

@@ -54,43 +53,6 @@ struct UniformGenerator {
   }
 };
-template <typename T>
-struct UniformGeneratorOffset {
-  T min_, max_;
-  unsigned int seed_;
-  T diag_val_;
-  unsigned int diag_num_;
-  unsigned int diag_step_;
-  int offset_;
-  __host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
-                                             int diag_num, int diag_step,
-                                             T diag_val, int offset)
-      : min_(min),
-        max_(max),
-        seed_(seed),
-        diag_num_(diag_num),
-        diag_step_(diag_step),
-        diag_val_(diag_val),
-        offset_(offset) {}
-  __host__ __device__ T operator()(const unsigned int n) const {
-    thrust::minstd_rand rng;
-    rng.seed(seed_);
-    thrust::uniform_real_distribution<T> dist(min_, max_);
-    rng.discard(n + offset_);
-    T out = dist(rng);
-    unsigned int remainder = n % (diag_step_ + 1);
-    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
-      out = diag_val_;
-    }
-    return out;
-  }
-};
 template <typename T, typename Context>
 void UniformRandomRawKernel(const Context& dev_ctx,
                             const IntArray& shape,

@@ -114,23 +76,10 @@ void UniformRandomRawKernel(const Context& dev_ctx,
   auto generator = dev_ctx.GetGenerator();
   if (generator->GetIsInitPy() && seed_flag) {
-    if (FLAGS_use_curand) {
-      using MT = typename kps::details::MPTypeTrait<T>::Type;
-      funcs::uniform_distribution<MT> dist;
-      funcs::uniform_real_transform<MT> trans(min, max);
-      funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
-    } else {
-      auto seed_offset = generator->IncrementOffset(1);
-      int64_t gen_offset = size * seed_offset.second;
-      auto func = UniformGeneratorOffset<T>(min, max, seed_offset.first,
-                                            diag_num, diag_step, diag_val,
-                                            gen_offset);
-      IndexKernel<T, UniformGeneratorOffset<T>>(dev_ctx, out, func);
-    }
+    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    funcs::uniform_distribution<MT> dist;
+    funcs::uniform_real_transform<MT> trans(min, max);
+    funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
   } else {
     auto func = UniformGenerator<T>(min, max, seed, diag_num, diag_step,
                                     diag_val);
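For real-valued kernels, uniform_real_transform presumably applies an affine map from the generator's unit-interval output onto [min, max); a minimal sketch under that assumption (not Paddle's exact functor):

    template <typename T>
    struct UniformRealTransformSketch {  // assumption, for illustration only
      T min_, max_;
      UniformRealTransformSketch(T min, T max) : min_(min), max_(max) {}
      __host__ __device__ T operator()(T u) const {
        return min_ + u * (max_ - min_);  // u in [0, 1)  ->  [min, max)
      }
    };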
paddle/scripts/paddle_build.bat

@@ -657,7 +657,6 @@ for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%#
 set start=%start:~4,10%
 set FLAGS_call_stack_level=2
-set FLAGS_use_curand=True
 dir %THIRD_PARTY_PATH:/=\%\install\openblas\lib
 dir %THIRD_PARTY_PATH:/=\%\install\openblas\bin
 dir %THIRD_PARTY_PATH:/=\%\install\zlib\bin
paddle/scripts/paddle_build.sh

@@ -61,8 +61,6 @@ function init() {
     # NOTE(chenweihang): For easy debugging, CI displays the C++ error stacktrace by default
     export FLAGS_call_stack_level=2
-    export FLAGS_use_curand=True
     # set CI_SKIP_CPP_TEST if only *.py changed
     # In order to avoid using in some CI(such as daily performance), the current
     # branch must not be `${BRANCH}` which is usually develop.
python/paddle/fluid/initializer.py

@@ -561,12 +561,12 @@ class XavierInitializer(Initializer):
         if framework._non_static_mode():
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in + fan_out))
+                limit = math.sqrt(6.0 / float(fan_in + fan_out))
                 out_var = _C_ops.uniform_random('shape', out_var.shape, 'min',
                                                 -limit, 'max', limit, 'seed',
                                                 self._seed, 'dtype', out_dtype)
             else:
-                std = np.sqrt(2.0 / float(fan_in + fan_out))
+                std = math.sqrt(2.0 / float(fan_in + fan_out))
                 out_var = _C_ops.gaussian_random('shape', out_var.shape,
                                                  'dtype', out_dtype, 'mean',
                                                  0.0, 'std', std, 'seed',
                                                  self._seed)

@@ -581,7 +581,7 @@ class XavierInitializer(Initializer):
             return None
         else:
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in + fan_out))
+                limit = math.sqrt(6.0 / float(fan_in + fan_out))
                 op = block.append_op(
                     type="uniform_random",
                     inputs={},

@@ -595,7 +595,7 @@ class XavierInitializer(Initializer):
                     },
                     stop_gradient=True)
             else:
-                std = np.sqrt(2.0 / float(fan_in + fan_out))
+                std = math.sqrt(2.0 / float(fan_in + fan_out))
                 op = block.append_op(
                     type="gaussian_random",
                     outputs={"Out": out_var},

@@ -713,13 +713,13 @@ class MSRAInitializer(Initializer):
         if framework._non_static_mode():
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in))
+                limit = math.sqrt(6.0 / float(fan_in))
                 out_var = _C_ops.uniform_random('shape', out_var.shape, 'min',
                                                 -limit, 'max', limit, 'seed',
                                                 self._seed, 'dtype',
                                                 int(out_dtype))
             else:
-                std = np.sqrt(2.0 / float(fan_in))
+                std = math.sqrt(2.0 / float(fan_in))
                 out_var = _C_ops.gaussian_random(
                     'shape', out_var.shape, 'dtype', int(out_dtype), 'mean',
                     0.0, 'std', std, 'seed', self._seed)

@@ -734,7 +734,7 @@ class MSRAInitializer(Initializer):
             return None
         else:
             if self._uniform:
-                limit = np.sqrt(6.0 / float(fan_in))
+                limit = math.sqrt(6.0 / float(fan_in))
                 op = block.append_op(
                     type="uniform_random",
                     inputs={},

@@ -749,7 +749,7 @@ class MSRAInitializer(Initializer):
                     stop_gradient=True)
             else:
-                std = np.sqrt(2.0 / float(fan_in))
+                std = math.sqrt(2.0 / float(fan_in))
                 op = block.append_op(
                     type="gaussian_random",
                     outputs={"Out": out_var},
python/paddle/fluid/tests/unittests/test_bernoulli_op.py

@@ -75,9 +75,6 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_dropout_op.py

@@ -1034,9 +1034,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_exponential_op.py

@@ -100,9 +100,6 @@ class TestExponentialAPI(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py

@@ -342,9 +342,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         def _check_random_value(dtype, expect, expect_mean, expect_std):
             x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
             actual = x.numpy()
python/paddle/fluid/tests/unittests/test_linear.py

@@ -73,6 +73,22 @@ class LinearTestCase(unittest.TestCase):
         np.testing.assert_array_almost_equal(res_f, res_nn)
         np.testing.assert_array_almost_equal(res_nn, res_np)
 
+    def test_weight_init(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        paddle.seed(100)
+        linear = paddle.nn.Linear(
+            2, 3, weight_attr=paddle.nn.initializer.Normal(0, 1.))
+        paddle.nn.utils._stride_column(linear.weight)
+        expect = [[1.4349908, -0.8099171, -2.64788],
+                  [-1.4981681, -1.1784115, -0.023253186]]
+        self.assertTrue(np.allclose(linear.weight.numpy(), expect))
+
+        linear = paddle.nn.Linear(2, 3)
+        expect = [[0.73261100, 0.43836895, 0.07908206],
+                  [0.85075015, -1.04724526, 0.64371765]]
+        self.assertTrue(np.allclose(linear.weight.numpy(), expect))
+
 
 if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/test_multinomial_op.py

@@ -227,9 +227,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_poisson_op.py

@@ -107,9 +107,6 @@ class TestPoissonAPI(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_randint_op.py

@@ -198,9 +198,6 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
python/paddle/fluid/tests/unittests/test_randperm_op.py

@@ -155,9 +155,6 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
         print("Test Fixed Random number on GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
python/paddle/fluid/tests/unittests/test_uniform_random_op.py

@@ -573,37 +573,46 @@ class TestRandomValue(unittest.TestCase):
         if not "V100" in paddle.device.cuda.get_device_name():
             return
-        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
-            return
-
-        def _check_random_value(dtype, expect, expect_mean, expect_std):
-            x = paddle.rand([32, 3, 1024, 1024], dtype=dtype)
-            actual = x.numpy()
-            self.assertTrue(np.allclose(actual[2, 1, 512, 1000:1010], expect))
-            self.assertEqual(np.mean(actual), expect_mean)
-            self.assertEqual(np.std(actual), expect_std)
 
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
         paddle.seed(2021)
+
+        expect_mean = 0.50000454338820143895816272561205551028251647949218750
+        expect_std = 0.28867379167297479991560749112977646291255950927734375
         expect = [
             0.55298901, 0.65184678, 0.49375412, 0.57943639, 0.16459608,
             0.67181056, 0.03021481, 0.0238559, 0.07742096, 0.55972187
         ]
-        expect_mean = 0.50000454338820143895816272561205551028251647949218750
-        expect_std = 0.28867379167297479991560749112977646291255950927734375
-        _check_random_value(core.VarDesc.VarType.FP64, expect, expect_mean,
-                            expect_std)
+        out = paddle.rand([32, 3, 1024, 1024], dtype='float64').numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect))
 
+        expect_mean = 0.50002604722976684570312500
+        expect_std = 0.2886914908885955810546875
         expect = [
             0.45320973, 0.17582087, 0.725341, 0.30849215, 0.622257,
             0.46352342, 0.97228295, 0.12771158, 0.286525, 0.9810645
         ]
-        expect_mean = 0.50002604722976684570312500
-        expect_std = 0.2886914908885955810546875
-        _check_random_value(core.VarDesc.VarType.FP32, expect, expect_mean,
-                            expect_std)
+        out = paddle.rand([32, 3, 1024, 1024], dtype='float32').numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[2, 1, 512, 1000:1010], expect))
+
+        expect_mean = 25.11843109130859375
+        expect_std = 43.370647430419921875
+        expect = [
+            30.089634, 77.05225, 3.1201615, 68.34072, 59.266724, -25.33281,
+            12.973292, 27.41127, -17.412298, 27.931019
+        ]
+        out = paddle.empty([16, 16, 16, 16],
+                           dtype='float32').uniform_(-50, 100).numpy()
+        self.assertEqual(np.mean(out), expect_mean)
+        self.assertEqual(np.std(out), expect_std)
+        self.assertTrue(np.allclose(out[10, 10, 10, 0:10], expect))
+
         paddle.enable_static()
python/paddle/nn/utils/__init__.py

@@ -14,7 +14,7 @@
 from .spectral_norm_hook import spectral_norm
 from .weight_norm_hook import weight_norm, remove_weight_norm  # noqa: F401
-from .transform_parameters import parameters_to_vector, vector_to_parameters  # noqa: F401
+from .transform_parameters import parameters_to_vector, vector_to_parameters, _stride_column  # noqa: F401
 __all__ = [  #noqa
     'weight_norm', 'remove_weight_norm', 'spectral_norm',
     'parameters_to_vector', 'vector_to_parameters'
python/paddle/nn/utils/transform_parameters.py

@@ -36,6 +36,39 @@ def _inplace_reshape_dygraph(x, shape):
             stop_gradient=True)
 
+@dygraph_only
+def _stride_column(param):
+    """
+    A tool function. Permute date of parameter as a 'columns' stride. Now, it
+    only support 2-D parameter.
+
+    Args:
+        param(Tensor]): The param that will be strided according to 'columns'.
+
+    Examples:
+       .. code-block:: python
+
+            import paddle
+            paddle.seed(100)
+
+            linear = paddle.nn.Linear(2, 3)
+            print(linear.weight)
+            # [[-0.31485492, -1.02896988,  0.45741916],
+            #  [-0.65525872, -1.04643178,  1.07262802]]
+
+            paddle.nn.utils.stride_column(linear.weight)
+            print(linear.weight)
+            # [[-0.31485492,  0.45741916, -1.04643178],
+            #  [-1.02896988, -0.65525872,  1.07262802]]
+
+    """
+    assert len(param.shape) == 2
+    shape = [param.shape[1], param.shape[0]]
+    with paddle.fluid.dygraph.no_grad():
+        reshape_var = paddle.reshape(param, shape)
+        transpose_var = paddle.transpose(reshape_var, [1, 0])
+        transpose_var._share_underline_tensor_to(param)
+
+
 @dygraph_only
 def parameters_to_vector(parameters, name=None):
     """
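What the new helper computes, for reference: for a 2-D parameter of shape $(r, c)$ stored row-major as $\mathrm{flat}[k]$, reshaping to $(c, r)$ and then transposing yields

$$\mathrm{new}[i][j] = \mathrm{flat}[\,j \cdot r + i\,], \qquad 0 \le i < r,\; 0 \le j < c,$$

i.e. the same buffer reread in column-major order, which reproduces the before/after values in the docstring example. (Note the docstring example calls paddle.nn.utils.stride_column while the function is actually exported as _stride_column.)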