Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
58970995
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
58970995
编写于
3月 23, 2022
作者:
zhouweiwei2014
提交者:
GitHub
3月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change CUDA implementation of multinomial OP (#40752)
上级
95d3ebc8
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
488 addition
and
44 deletion
+488
-44
paddle/phi/kernels/funcs/distribution_helper.h
paddle/phi/kernels/funcs/distribution_helper.h
+8
-4
paddle/phi/kernels/funcs/inclusive_scan.h
paddle/phi/kernels/funcs/inclusive_scan.h
+274
-0
paddle/phi/kernels/gpu/multinomial_kernel.cu
paddle/phi/kernels/gpu/multinomial_kernel.cu
+151
-40
python/paddle/fluid/tests/unittests/test_multinomial_op.py
python/paddle/fluid/tests/unittests/test_multinomial_op.py
+55
-0
未找到文件。
paddle/phi/kernels/funcs/distribution_helper.h
浏览文件 @
58970995
...
...
@@ -50,11 +50,15 @@ struct exponential_transform {
HOSTDEVICE
inline
T
operator
()(
T
val
)
const
{
#if defined(__NVCC__) || defined(__HIPCC__)
T
log
=
-
std
::
numeric_limits
<
T
>::
epsilon
()
/
2
;
if
(
val
<
static_cast
<
T
>
(
1.
)
-
std
::
numeric_limits
<
T
>::
epsilon
()
/
2
)
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
return
static_cast
<
T
>
(
-
1.0
)
/
lambda_
*
log
(
val
);
log
=
logf
(
val
);
}
else
{
return
static_cast
<
T
>
(
-
1.0
)
/
lambda_
*
__logf
(
val
);
log
=
__logf
(
val
);
}
}
return
static_cast
<
T
>
(
-
1.0
)
/
lambda_
*
log
;
#else
return
static_cast
<
T
>
(
-
1.0
)
/
lambda_
*
std
::
log
(
static_cast
<
T
>
(
1.0
)
-
val
);
#endif
...
...
paddle/phi/kernels/funcs/inclusive_scan.h
0 → 100644
浏览文件 @
58970995
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include <thrust/device_ptr.h>
#include <thrust/iterator/reverse_iterator.h>
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"
namespace
phi
{
namespace
funcs
{
template
<
typename
T
>
struct
IsComplex
:
public
std
::
false_type
{};
template
<
>
struct
IsComplex
<::
phi
::
dtype
::
complex
<
float
>>
:
public
std
::
true_type
{};
template
<
>
struct
IsComplex
<::
phi
::
dtype
::
complex
<
double
>>
:
public
std
::
true_type
{};
template
<
typename
InputIterator
,
typename
OutputIterator
,
typename
BinaryOp
>
static
void
CubInclusiveScan
(
InputIterator
x_iter
,
OutputIterator
y_iter
,
size_t
n
,
BinaryOp
op
,
const
phi
::
GPUContext
&
dev_ctx
)
{
paddle
::
memory
::
allocation
::
AllocationPtr
allocation
;
void
*
temp_storage
=
nullptr
;
size_t
temp_storage_bytes
=
0
;
for
(
size_t
i
=
0
;
i
<
2
;
++
i
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceScan
::
InclusiveScan
(
temp_storage
,
temp_storage_bytes
,
x_iter
,
y_iter
,
op
,
static_cast
<
int
>
(
n
),
dev_ctx
.
stream
()));
if
(
i
==
0
&&
temp_storage_bytes
>
0
)
{
allocation
=
paddle
::
memory
::
Alloc
(
dev_ctx
.
GetPlace
(),
temp_storage_bytes
);
temp_storage
=
allocation
->
ptr
();
}
}
}
template
<
typename
T
>
static
auto
MakeThrustReverseIterator
(
T
*
x
)
{
return
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
T
>>
(
thrust
::
device_pointer_cast
(
x
));
}
template
<
typename
T
,
typename
BinaryOp
,
bool
kReverse
>
struct
InclusiveScanOuterOrMidDimFunctor
{
HOSTDEVICE
InclusiveScanOuterOrMidDimFunctor
(
const
T
*
x
,
T
*
y
,
size_t
mid_dim
,
size_t
inner_dim
,
T
init
,
BinaryOp
op
)
:
x_
(
x
),
y_
(
y
),
mid_dim_
(
mid_dim
),
inner_dim_
(
inner_dim
),
init_
(
init
),
op_
(
op
)
{}
HOSTDEVICE
void
operator
()(
size_t
idx
)
const
{
auto
outer_idx
=
idx
/
inner_dim_
;
auto
inner_idx
=
idx
%
inner_dim_
;
if
(
kReverse
)
{
idx
=
outer_idx
*
mid_dim_
*
inner_dim_
+
(
mid_dim_
-
1
)
*
inner_dim_
+
inner_idx
;
}
else
{
idx
=
outer_idx
*
mid_dim_
*
inner_dim_
+
inner_idx
;
}
auto
x_ptr
=
x_
+
idx
;
auto
y_ptr
=
y_
+
idx
;
T
acc_value
=
init_
;
for
(
size_t
i
=
0
;
i
<
mid_dim_
;
++
i
)
{
acc_value
=
op_
(
acc_value
,
*
x_ptr
);
*
y_ptr
=
acc_value
;
if
(
kReverse
)
{
x_ptr
-=
inner_dim_
;
y_ptr
-=
inner_dim_
;
}
else
{
x_ptr
+=
inner_dim_
;
y_ptr
+=
inner_dim_
;
}
}
}
private:
const
T
*
x_
;
T
*
y_
;
size_t
mid_dim_
;
size_t
inner_dim_
;
T
init_
;
BinaryOp
op_
;
};
template
<
typename
T
,
typename
BinaryOp
,
size_t
kThreadNumX
,
size_t
kThreadNumY
,
bool
kReverse
>
static
__global__
void
InclusiveScanInnerDimCUDAKernel
(
const
T
*
x
,
T
*
y
,
size_t
num_rows
,
size_t
row_size
,
T
init
,
BinaryOp
op
)
{
using
RealT
=
phi
::
dtype
::
Real
<
T
>
;
constexpr
auto
kSharedBufferSize
=
IsComplex
<
T
>::
value
?
4
*
kThreadNumX
:
2
*
kThreadNumX
;
__shared__
RealT
sbuf
[
kThreadNumY
][
kSharedBufferSize
];
T
*
row_buf
=
reinterpret_cast
<
T
*>
(
sbuf
[
threadIdx
.
y
]);
size_t
block_row
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
kThreadNumY
);
size_t
block_row_stride
=
static_cast
<
size_t
>
(
gridDim
.
x
*
kThreadNumY
);
for
(;
block_row
<
num_rows
;
block_row
+=
block_row_stride
)
{
size_t
row
=
block_row
+
threadIdx
.
y
;
T
block_total
=
init
;
const
T
*
row_x
=
x
+
row
*
row_size
;
T
*
row_y
=
y
+
row
*
row_size
;
for
(
size_t
block_col
=
0
;
block_col
<
row_size
;
block_col
+=
2
*
kThreadNumX
)
{
size_t
col1
,
col2
;
if
(
kReverse
)
{
col1
=
row_size
-
1
-
block_col
-
threadIdx
.
x
;
col2
=
col1
-
kThreadNumX
;
}
else
{
col1
=
block_col
+
threadIdx
.
x
;
col2
=
col1
+
kThreadNumX
;
}
if
(
row
<
num_rows
)
{
if
(
col1
<
row_size
)
{
row_buf
[
threadIdx
.
x
]
=
row_x
[
col1
];
}
else
{
row_buf
[
threadIdx
.
x
]
=
init
;
}
if
(
col2
<
row_size
)
{
row_buf
[
kThreadNumX
+
threadIdx
.
x
]
=
row_x
[
col2
];
}
else
{
row_buf
[
kThreadNumX
+
threadIdx
.
x
]
=
init
;
}
if
(
threadIdx
.
x
==
0
)
{
row_buf
[
0
]
=
op
(
row_buf
[
0
],
block_total
);
}
}
__syncthreads
();
for
(
size_t
s
=
kThreadNumX
,
d
=
1
;
s
>=
1
;
s
>>=
1
,
d
<<=
1
)
{
if
(
row
<
num_rows
&&
threadIdx
.
x
<
s
)
{
size_t
offset
=
(
2
*
threadIdx
.
x
+
1
)
*
d
-
1
;
row_buf
[
offset
+
d
]
=
op
(
row_buf
[
offset
],
row_buf
[
offset
+
d
]);
}
__syncthreads
();
}
for
(
size_t
s
=
2
,
d
=
kThreadNumX
/
2
;
d
>=
1
;
s
<<=
1
,
d
>>=
1
)
{
if
(
row
<
num_rows
&&
threadIdx
.
x
<
s
-
1
)
{
size_t
offset
=
2
*
(
threadIdx
.
x
+
1
)
*
d
-
1
;
row_buf
[
offset
+
d
]
=
op
(
row_buf
[
offset
],
row_buf
[
offset
+
d
]);
}
__syncthreads
();
}
if
(
row
<
num_rows
)
{
if
(
col1
<
row_size
)
row_y
[
col1
]
=
row_buf
[
threadIdx
.
x
];
if
(
col2
<
row_size
)
row_y
[
col2
]
=
row_buf
[
kThreadNumX
+
threadIdx
.
x
];
}
block_total
=
row_buf
[
2
*
kThreadNumX
-
1
];
__syncthreads
();
}
}
}
template
<
typename
T
,
typename
BinaryOp
>
static
void
InclusiveScanInnerDim
(
const
T
*
x
,
T
*
y
,
size_t
outer_dim
,
size_t
inner_dim
,
T
init
,
BinaryOp
op
,
bool
reverse
,
const
phi
::
GPUContext
&
dev_ctx
)
{
constexpr
size_t
kThreadNumX
=
16
;
constexpr
size_t
kThreadNumY
=
32
;
size_t
grid_dim
=
(
outer_dim
+
kThreadNumY
-
1
)
/
kThreadNumY
;
grid_dim
=
std
::
min
<
size_t
>
(
grid_dim
,
dev_ctx
.
GetCUDAMaxGridDimSize
()[
0
]);
dim3
thread_dims
(
kThreadNumX
,
kThreadNumY
);
if
(
reverse
)
{
InclusiveScanInnerDimCUDAKernel
<
T
,
BinaryOp
,
kThreadNumX
,
kThreadNumY
,
/*kReverse=*/
true
><<<
grid_dim
,
thread_dims
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
y
,
outer_dim
,
inner_dim
,
init
,
op
);
}
else
{
InclusiveScanInnerDimCUDAKernel
<
T
,
BinaryOp
,
kThreadNumX
,
kThreadNumY
,
/*kReverse=*/
false
><<<
grid_dim
,
thread_dims
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
y
,
outer_dim
,
inner_dim
,
init
,
op
);
}
}
template
<
typename
T
,
typename
BinaryOp
>
void
InclusiveScan
(
const
T
*
x
,
T
*
y
,
size_t
outer_dim
,
size_t
mid_dim
,
size_t
inner_dim
,
T
init
,
BinaryOp
op
,
bool
reverse
,
const
phi
::
GPUContext
&
dev_ctx
)
{
if
(
outer_dim
==
0
||
mid_dim
==
0
||
inner_dim
==
0
)
return
;
if
(
outer_dim
==
1
&&
inner_dim
==
1
)
{
if
(
reverse
)
{
auto
x_reverse_iter
=
MakeThrustReverseIterator
(
x
+
mid_dim
);
auto
y_reverse_iter
=
MakeThrustReverseIterator
(
y
+
mid_dim
);
CubInclusiveScan
(
x_reverse_iter
,
y_reverse_iter
,
mid_dim
,
op
,
dev_ctx
);
}
else
{
CubInclusiveScan
(
x
,
y
,
mid_dim
,
op
,
dev_ctx
);
}
}
else
if
(
inner_dim
!=
1
)
{
phi
::
funcs
::
ForRange
<
phi
::
GPUContext
>
for_range
(
dev_ctx
,
outer_dim
*
inner_dim
);
if
(
reverse
)
{
for_range
(
InclusiveScanOuterOrMidDimFunctor
<
T
,
BinaryOp
,
/*kReverse=*/
true
>
(
x
,
y
,
mid_dim
,
inner_dim
,
init
,
op
));
}
else
{
for_range
(
InclusiveScanOuterOrMidDimFunctor
<
T
,
BinaryOp
,
/*kReverse=*/
false
>
(
x
,
y
,
mid_dim
,
inner_dim
,
init
,
op
));
}
}
else
{
InclusiveScanInnerDim
<
T
,
BinaryOp
>
(
x
,
y
,
outer_dim
,
mid_dim
,
init
,
op
,
reverse
,
dev_ctx
);
}
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/gpu/multinomial_kernel.cu
浏览文件 @
58970995
...
...
@@ -23,11 +23,32 @@ limitations under the License. */
#include <thrust/scan.h>
#include <thrust/transform.h>
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"
#include "paddle/phi/kernels/funcs/multinomial_functor.h"
#include "paddle/phi/kernels/top_k_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/transform.h"
DECLARE_bool
(
use_curand
);
namespace
phi
{
...
...
@@ -57,12 +78,12 @@ template <typename T>
__global__
void
GetCumulativeProbs
(
T
*
norm_probs_data
,
int64_t
num_distributions
,
int64_t
num_categories
,
T
*
cumulative_probs
)
{
T
*
cumulative_probs
_data
)
{
int
id
=
blockIdx
.
x
;
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
(
id
+
1
)
*
num_categories
,
cumulative_probs
+
id
*
num_categories
);
cumulative_probs
_data
+
id
*
num_categories
);
}
template
<
typename
T
>
...
...
@@ -80,7 +101,7 @@ struct RandomGeneratorCudaFunctor {
};
template
<
typename
T
>
__device__
int
binarySearchFunctor
(
T
*
cumulative_probs
,
__device__
int
binarySearchFunctor
(
T
*
cumulative_probs
_data
,
T
*
norm_probs_data
,
int
num_categories
,
T
rng_number
)
{
...
...
@@ -90,7 +111,7 @@ __device__ int binarySearchFunctor(T* cumulative_probs,
while
(
right
-
left
>
0
)
{
int
mid
=
left
+
(
right
-
left
)
/
2
;
T
temp_prob
=
cumulative_probs
[
mid
];
T
temp_prob
=
cumulative_probs
_data
[
mid
];
if
(
temp_prob
<
rng_number
)
{
left
=
mid
+
1
;
}
else
{
...
...
@@ -114,27 +135,36 @@ __global__ void sampleMultinomialWithReplacement(
int64_t
*
out_data
,
const
int64_t
num_distributions
,
const
int64_t
num_categories
,
T
*
cumulative_probs
,
T
*
norm_probs_data
)
{
T
*
cumulative_probs_data
,
T
*
norm_probs_data
,
uint64_t
seed
,
uint64_t
offset
,
bool
use_curand
)
{
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
// let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id].
size_t
idx
=
gridDim
.
x
*
blockDim
.
x
*
blockIdx
.
y
+
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
offset
,
&
state
);
// for every distribution
int
dist
=
blockIdx
.
y
;
// for every sample
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
dist
=
blockIdx
.
y
;
dist
<
num_distributions
;
dist
+=
gridDim
.
y
)
{
if
(
sample
<
num_samples
)
{
T
rng_number
=
rng_data
[
sample
+
dist
*
num_samples
];
if
(
use_curand
)
{
rng_number
=
static_cast
<
T
>
(
curand_uniform4
(
&
state
).
x
);
}
// Find the bucket that a uniform random number lies in
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs
+
dist
*
num_categories
,
binarySearchFunctor
<
T
>
(
cumulative_probs_data
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
num_categories
,
rng_number
);
out_data
[
sample
+
dist
*
num_samples
]
=
selected_category
;
}
}
}
template
<
typename
T
,
typename
Context
>
...
...
@@ -172,6 +202,54 @@ void MultinomialKernel(const Context& dev_ctx,
in_data_numel
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
);
#endif
if
(
FLAGS_use_curand
)
{
for
(
size_t
i
=
0
;
i
<
num_distributions
;
++
i
)
{
int
zero_num
=
0
;
for
(
size_t
j
=
0
;
j
<
num_categories
;
++
j
)
{
T
weight
=
cpu_in_data
[
i
*
num_distributions
+
j
];
PADDLE_ENFORCE_GE
(
weight
,
0
,
errors
::
InvalidArgument
(
"Each element of multinomial'input must >= 0, but got %f."
,
weight
));
if
(
weight
==
static_cast
<
T
>
(
0
))
{
zero_num
++
;
}
}
int
valid_samples
=
num_categories
-
zero_num
;
PADDLE_ENFORCE_LE
(
num_samples
,
valid_samples
,
errors
::
InvalidArgument
(
"When replacement=False, 'num_samples' "
"must less than or eaqual to the number of "
"positive item of input"
));
}
// Refer to [gumbel softmax algorithm]
DenseTensor
rand
=
EmptyLike
<
T
,
Context
>
(
dev_ctx
,
x
);
T
*
rand_data
=
rand
.
data
<
T
>
();
funcs
::
uniform_distribution
<
T
>
dist
;
funcs
::
exponential_transform
<
T
>
trans
(
1.0
);
funcs
::
distribution_and_transform
<
T
>
(
dev_ctx
,
&
rand
,
dist
,
trans
);
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
x
.
numel
());
for_range
([
rand_data
,
in_data
]
__device__
(
size_t
idx
)
{
rand_data
[
idx
]
=
in_data
[
idx
]
/
rand_data
[
idx
];
});
if
(
num_samples
==
1
)
{
ArgMaxKernel
<
T
,
Context
>
(
dev_ctx
,
rand
,
-
1
,
true
,
false
,
3
/*proto::VarType::INT64*/
,
out
);
}
else
{
std
::
vector
<
int64_t
>
out_dim_vec
=
vectorize
<
int64_t
>
(
out
->
dims
());
DenseTensor
value
=
Empty
<
T
,
Context
>
(
dev_ctx
,
ScalarArray
(
out_dim_vec
));
TopkKernel
<
T
,
Context
>
(
dev_ctx
,
rand
,
Scalar
(
num_samples
),
-
1
,
true
,
true
,
&
value
,
out
);
}
return
;
}
funcs
::
MultinomialFunctor
<
T
>
(
dev_ctx
,
cpu_out_data
,
...
...
@@ -228,7 +306,8 @@ void MultinomialKernel(const Context& dev_ctx,
auto
*
norm_probs_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
norm_probs_tensor
);
// number of threads in a block is min(num_categories, 512)
dim3
block_norm
(
num_categories
<
512
?
num_categories
:
512
);
int
block_size
=
num_categories
<
512
?
num_categories
:
512
;
dim3
block_norm
(
block_size
);
dim3
grid_norm
((
num_distributions
*
num_categories
-
1
)
/
block_norm
.
x
+
1
);
NormalizeProbability
<
T
><<<
grid_norm
,
block_norm
,
0
,
dev_ctx
.
stream
()
>>>
(
norm_probs_data
,
...
...
@@ -238,16 +317,34 @@ void MultinomialKernel(const Context& dev_ctx,
num_categories
);
// Get cumulative probability of each distribution. It's the same function
// of
// ``cumsum`` op.
// of ``cumsum`` op.
DenseTensor
cumulative_probs_tensor
;
cumulative_probs_tensor
.
Resize
({
num_distributions
,
num_categories
});
auto
*
cumulative_probs
=
dev_ctx
.
template
Alloc
<
T
>(
&
cumulative_probs_tensor
);
auto
*
cumulative_probs_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
cumulative_probs_tensor
);
if
(
FLAGS_use_curand
)
{
// 'phi::funcs::InclusiveScan' has higher accuracy than
// 'thrust::inclusive_scan'
funcs
::
InclusiveScan
<
T
,
std
::
plus
<
T
>>
(
/*in*/
norm_probs_data
,
/*out*/
cumulative_probs_data
,
/*outer_dim*/
static_cast
<
size_t
>
(
num_distributions
),
/*mid_dim*/
static_cast
<
size_t
>
(
num_categories
),
/*inner_dim*/
static_cast
<
size_t
>
(
1
),
/*init*/
static_cast
<
T
>
(
0
),
std
::
plus
<
T
>
(),
/*reverse=*/
false
,
dev_ctx
);
}
else
{
dim3
block_cumsum
(
1
);
dim3
grid_cumsum
(
num_distributions
);
GetCumulativeProbs
<
T
><<<
grid_cumsum
,
block_cumsum
,
0
,
dev_ctx
.
stream
()
>>>
(
norm_probs_data
,
num_distributions
,
num_categories
,
cumulative_probs
);
norm_probs_data
,
num_distributions
,
num_categories
,
cumulative_probs_data
);
}
// Generate random number for each sample.
std
::
random_device
rd
;
...
...
@@ -266,16 +363,30 @@ void MultinomialKernel(const Context& dev_ctx,
RandomGeneratorCudaFunctor
<
T
>
(
seed
));
// Sample the multinomial distributions.
dim3
block_sample
(
128
);
dim3
grid_sample
((
num_samples
-
1
)
/
block_sample
.
x
+
1
,
num_distributions
);
sampleMultinomialWithReplacement
<
T
><<<
grid_sample
,
block_sample
,
0
,
dev_ctx
.
stream
()
>>>
(
rng_data
,
dim3
block
(
128
);
int64_t
device_id
=
dev_ctx
.
GetPlace
().
GetDeviceId
();
const
auto
&
prop
=
phi
::
backends
::
gpu
::
GetDeviceProperties
(
device_id
);
int
grid_y
=
std
::
min
<
int64_t
>
(
num_distributions
,
prop
.
maxGridSize
[
1
]);
dim3
grid
((
num_samples
-
1
)
/
block
.
x
+
1
,
grid_y
);
auto
gen_cuda
=
dev_ctx
.
GetGenerator
();
size_t
curand4_loop_times
=
(
num_distributions
+
4
*
grid_y
-
1
)
/
(
4
*
grid_y
);
// 'increment' shoulde be multiple of 4
uint64_t
increment
=
curand4_loop_times
*
4
;
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
increment
);
sampleMultinomialWithReplacement
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
rng_data
,
num_samples
,
out_data
,
num_distributions
,
num_categories
,
cumulative_probs
,
norm_probs_data
);
cumulative_probs_data
,
norm_probs_data
,
seed_offset
.
first
,
seed_offset
.
second
,
FLAGS_use_curand
);
}
}
// namespace phi
...
...
python/paddle/fluid/tests/unittests/test_multinomial_op.py
浏览文件 @
58970995
...
...
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
from
paddle.fluid
import
core
from
op_test
import
OpTest
import
numpy
as
np
import
os
def
sample_output_one_dimension
(
out
,
dim
):
...
...
@@ -216,5 +217,59 @@ class TestMultinomialError(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
test_dim_less_than_1
)
class
TestRandomValue
(
unittest
.
TestCase
):
def
test_fixed_random_number
(
self
):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if
not
paddle
.
is_compiled_with_cuda
():
return
# Different GPU generatte different random value. Only test V100 here.
if
not
"V100"
in
paddle
.
device
.
cuda
.
get_device_name
():
return
if
os
.
getenv
(
"FLAGS_use_curand"
,
None
)
in
(
'0'
,
'False'
,
None
):
return
print
(
"Test Fixed Random number on V100 GPU------>"
)
paddle
.
disable_static
()
paddle
.
set_device
(
'gpu'
)
paddle
.
seed
(
100
)
x
=
paddle
.
randint
(
0
,
100
,
[
1024
,
10000
]).
astype
(
'float32'
)
y
=
paddle
.
multinomial
(
x
,
1
,
replacement
=
False
).
numpy
()
self
.
assertEqual
(
np
.
sum
(
y
),
5187793
)
self
.
assertEqual
(
np
.
mean
(
y
),
5066.2041015625
)
expect
=
[
9982
,
1655
,
4741
,
1323
,
9319
,
3298
,
6473
,
7477
,
2507
,
2628
]
self
.
assertTrue
(
np
.
array_equal
(
y
[
100
:
110
,
:].
flatten
(),
expect
))
y
=
paddle
.
multinomial
(
x
,
5000
,
replacement
=
False
).
numpy
()
self
.
assertEqual
(
np
.
sum
(
y
),
25603962316
)
self
.
assertEqual
(
np
.
mean
(
y
),
5000.77388984375
)
expect
=
[
7300
,
6055
,
8714
,
5401
,
7360
,
161
,
5035
,
7002
,
6788
,
2916
]
self
.
assertTrue
(
np
.
array_equal
(
y
[
100
,
1000
:
1010
],
expect
))
y
=
paddle
.
multinomial
(
x
,
5000
,
replacement
=
False
).
numpy
()
self
.
assertEqual
(
np
.
sum
(
y
),
25592855710
)
self
.
assertEqual
(
np
.
mean
(
y
),
4998.604630859375
)
expect
=
[
5700
,
6567
,
4399
,
5688
,
7472
,
545
,
6894
,
526
,
2124
,
385
]
self
.
assertTrue
(
np
.
array_equal
(
y
[
300
,
3000
:
3010
],
expect
))
y
=
paddle
.
multinomial
(
x
,
20000
,
replacement
=
True
).
numpy
()
self
.
assertEqual
(
np
.
sum
(
y
),
102371362581
)
self
.
assertEqual
(
np
.
mean
(
y
),
4998.60168852539
)
self
.
assertEqual
(
np
.
std
(
y
),
2886.316308500771
)
expect
=
[
7630
,
8235
,
8445
,
3275
,
5580
,
4591
,
1331
,
342
,
1662
,
7156
]
self
.
assertTrue
(
np
.
array_equal
(
y
[
100
,
0
:
10
],
expect
))
y
=
paddle
.
multinomial
(
x
,
20000
,
replacement
=
True
).
numpy
()
self
.
assertEqual
(
np
.
sum
(
y
),
102400672117
)
self
.
assertEqual
(
np
.
mean
(
y
),
5000.032818212891
)
self
.
assertEqual
(
np
.
std
(
y
),
2886.913426124017
)
expect
=
[
4159
,
7849
,
9305
,
5759
,
4422
,
122
,
345
,
2897
,
5200
,
5911
]
self
.
assertTrue
(
np
.
array_equal
(
y
[
100
,
0
:
10
],
expect
))
paddle
.
enable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录