PaddlePaddle / Paddle

Commit 58970995 (unverified)
Authored on Mar 23, 2022 by zhouweiwei2014; committed via GitHub on Mar 23, 2022
change CUDA implementation of multinomial OP (#40752)
Parent: 95d3ebc8
Showing 4 changed files with 488 additions and 44 deletions (+488 −44)
paddle/phi/kernels/funcs/distribution_helper.h              +8   −4
paddle/phi/kernels/funcs/inclusive_scan.h                   +274 −0
paddle/phi/kernels/gpu/multinomial_kernel.cu                +151 −40
python/paddle/fluid/tests/unittests/test_multinomial_op.py  +55  −0
paddle/phi/kernels/funcs/distribution_helper.h
...
...
@@ -50,11 +50,15 @@ struct exponential_transform {
   HOSTDEVICE inline T operator()(T val) const {
 #if defined(__NVCC__) || defined(__HIPCC__)
-    if (std::is_same<T, double>::value) {
-      return static_cast<T>(-1.0) / lambda_ * log(val);
-    } else {
-      return static_cast<T>(-1.0) / lambda_ * __logf(val);
-    }
+    T log = -std::numeric_limits<T>::epsilon() / 2;
+    if (val < static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2) {
+      if (std::is_same<T, double>::value) {
+        log = logf(val);
+      } else {
+        log = __logf(val);
+      }
+    }
+    return static_cast<T>(-1.0) / lambda_ * log;
 #else
     return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val);
 #endif
...
...
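For context on the hunk above: the old code returned -log(val) / lambda_ directly, so a uniform sample equal to 1.0 produced an exponential sample of exactly 0; the new code clamps log(val) at -epsilon/2 so the result stays strictly positive (the multinomial kernel below later divides weights by these samples). The following is a minimal host-side C++ sketch of that clamped inverse-CDF transform, not part of the commit; the function name and the lambda parameter are illustrative only.

    #include <cmath>
    #include <cstdio>
    #include <limits>

    // Sketch: map a uniform sample u in (0, 1] to an Exponential(lambda) sample,
    // clamping log(u) at -eps/2 so u == 1.0 does not yield exactly 0.
    double exponential_from_uniform(double u, double lambda) {
      double log_u = -std::numeric_limits<double>::epsilon() / 2;
      if (u < 1.0 - std::numeric_limits<double>::epsilon() / 2) {
        log_u = std::log(u);
      }
      return -1.0 / lambda * log_u;
    }

    int main() {
      std::printf("%g\n", exponential_from_uniform(1.0, 1.0));   // tiny positive value, not 0
      std::printf("%g\n", exponential_from_uniform(0.25, 1.0));  // ~1.386
      return 0;
    }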
paddle/phi/kernels/funcs/inclusive_scan.h
new file mode 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <thrust/device_ptr.h>
#include <thrust/iterator/reverse_iterator.h>
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"
namespace phi {
namespace funcs {

template <typename T>
struct IsComplex : public std::false_type {};

template <>
struct IsComplex<::phi::dtype::complex<float>> : public std::true_type {};

template <>
struct IsComplex<::phi::dtype::complex<double>> : public std::true_type {};

template <typename InputIterator, typename OutputIterator, typename BinaryOp>
static void CubInclusiveScan(InputIterator x_iter,
                             OutputIterator y_iter,
                             size_t n,
                             BinaryOp op,
                             const phi::GPUContext &dev_ctx) {
  paddle::memory::allocation::AllocationPtr allocation;
  void *temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  for (size_t i = 0; i < 2; ++i) {
    PADDLE_ENFORCE_GPU_SUCCESS(
        cub::DeviceScan::InclusiveScan(temp_storage,
                                       temp_storage_bytes,
                                       x_iter,
                                       y_iter,
                                       op,
                                       static_cast<int>(n),
                                       dev_ctx.stream()));
    if (i == 0 && temp_storage_bytes > 0) {
      allocation =
          paddle::memory::Alloc(dev_ctx.GetPlace(), temp_storage_bytes);
      temp_storage = allocation->ptr();
    }
  }
}

template <typename T>
static auto MakeThrustReverseIterator(T *x) {
  return thrust::reverse_iterator<thrust::device_ptr<T>>(
      thrust::device_pointer_cast(x));
}

template <typename T, typename BinaryOp, bool kReverse>
struct InclusiveScanOuterOrMidDimFunctor {
  HOSTDEVICE InclusiveScanOuterOrMidDimFunctor(
      const T *x, T *y, size_t mid_dim, size_t inner_dim, T init, BinaryOp op)
      : x_(x),
        y_(y),
        mid_dim_(mid_dim),
        inner_dim_(inner_dim),
        init_(init),
        op_(op) {}

  HOSTDEVICE void operator()(size_t idx) const {
    auto outer_idx = idx / inner_dim_;
    auto inner_idx = idx % inner_dim_;
    if (kReverse) {
      idx = outer_idx * mid_dim_ * inner_dim_ + (mid_dim_ - 1) * inner_dim_ +
            inner_idx;
    } else {
      idx = outer_idx * mid_dim_ * inner_dim_ + inner_idx;
    }

    auto x_ptr = x_ + idx;
    auto y_ptr = y_ + idx;
    T acc_value = init_;
    for (size_t i = 0; i < mid_dim_; ++i) {
      acc_value = op_(acc_value, *x_ptr);
      *y_ptr = acc_value;
      if (kReverse) {
        x_ptr -= inner_dim_;
        y_ptr -= inner_dim_;
      } else {
        x_ptr += inner_dim_;
        y_ptr += inner_dim_;
      }
    }
  }

 private:
  const T *x_;
  T *y_;
  size_t mid_dim_;
  size_t inner_dim_;
  T init_;
  BinaryOp op_;
};

template <typename T,
          typename BinaryOp,
          size_t kThreadNumX,
          size_t kThreadNumY,
          bool kReverse>
static __global__ void InclusiveScanInnerDimCUDAKernel(
    const T *x, T *y, size_t num_rows, size_t row_size, T init, BinaryOp op) {
  using RealT = phi::dtype::Real<T>;
  constexpr auto kSharedBufferSize =
      IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
  __shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
  T *row_buf = reinterpret_cast<T *>(sbuf[threadIdx.y]);

  size_t block_row = static_cast<size_t>(blockIdx.x * kThreadNumY);
  size_t block_row_stride = static_cast<size_t>(gridDim.x * kThreadNumY);
  for (; block_row < num_rows; block_row += block_row_stride) {
    size_t row = block_row + threadIdx.y;
    T block_total = init;

    const T *row_x = x + row * row_size;
    T *row_y = y + row * row_size;
    for (size_t block_col = 0; block_col < row_size;
         block_col += 2 * kThreadNumX) {
      size_t col1, col2;
      if (kReverse) {
        col1 = row_size - 1 - block_col - threadIdx.x;
        col2 = col1 - kThreadNumX;
      } else {
        col1 = block_col + threadIdx.x;
        col2 = col1 + kThreadNumX;
      }

      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_x[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[kThreadNumX + threadIdx.x] = row_x[col2];
        } else {
          row_buf[kThreadNumX + threadIdx.x] = init;
        }

        if (threadIdx.x == 0) {
          row_buf[0] = op(row_buf[0], block_total);
        }
      }
      __syncthreads();

      for (size_t s = kThreadNumX, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          size_t offset = (2 * threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      for (size_t s = 2, d = kThreadNumX / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          size_t offset = 2 * (threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      if (row < num_rows) {
        if (col1 < row_size) row_y[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_y[col2] = row_buf[kThreadNumX + threadIdx.x];
      }
      block_total = row_buf[2 * kThreadNumX - 1];
      __syncthreads();
    }
  }
}

template <typename T, typename BinaryOp>
static void InclusiveScanInnerDim(const T *x,
                                  T *y,
                                  size_t outer_dim,
                                  size_t inner_dim,
                                  T init,
                                  BinaryOp op,
                                  bool reverse,
                                  const phi::GPUContext &dev_ctx) {
  constexpr size_t kThreadNumX = 16;
  constexpr size_t kThreadNumY = 32;

  size_t grid_dim = (outer_dim + kThreadNumY - 1) / kThreadNumY;
  grid_dim = std::min<size_t>(grid_dim, dev_ctx.GetCUDAMaxGridDimSize()[0]);
  dim3 thread_dims(kThreadNumX, kThreadNumY);
  if (reverse) {
    InclusiveScanInnerDimCUDAKernel<T,
                                    BinaryOp,
                                    kThreadNumX,
                                    kThreadNumY,
                                    /*kReverse=*/true>
        <<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
            x, y, outer_dim, inner_dim, init, op);
  } else {
    InclusiveScanInnerDimCUDAKernel<T,
                                    BinaryOp,
                                    kThreadNumX,
                                    kThreadNumY,
                                    /*kReverse=*/false>
        <<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
            x, y, outer_dim, inner_dim, init, op);
  }
}

template <typename T, typename BinaryOp>
void InclusiveScan(const T *x,
                   T *y,
                   size_t outer_dim,
                   size_t mid_dim,
                   size_t inner_dim,
                   T init,
                   BinaryOp op,
                   bool reverse,
                   const phi::GPUContext &dev_ctx) {
  if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;

  if (outer_dim == 1 && inner_dim == 1) {
    if (reverse) {
      auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim);
      auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim);
      CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx);
    } else {
      CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
    }
  } else if (inner_dim != 1) {
    phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx,
                                                    outer_dim * inner_dim);
    if (reverse) {
      for_range(
          InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/true>(
              x, y, mid_dim, inner_dim, init, op));
    } else {
      for_range(
          InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/false>(
              x, y, mid_dim, inner_dim, init, op));
    }
  } else {
    InclusiveScanInnerDim<T, BinaryOp>(
        x, y, outer_dim, mid_dim, init, op, reverse, dev_ctx);
  }
}

}  // namespace funcs
}  // namespace phi
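As a reading aid, here is a minimal CPU reference (not part of the commit) for what InclusiveScan above computes: x is viewed as a row-major [outer_dim, mid_dim, inner_dim] array and scanned along the mid dimension; the multinomial kernel below calls it with inner_dim = 1 and op = std::plus to turn per-row normalized probabilities into cumulative probabilities. All names in this sketch are illustrative.

    #include <cstdio>
    #include <vector>

    // CPU reference: y[o, m, i] = x[o, 0, i] op x[o, 1, i] op ... op x[o, m, i]
    template <typename T, typename BinaryOp>
    void InclusiveScanCPU(const T* x, T* y, size_t outer_dim, size_t mid_dim,
                          size_t inner_dim, T init, BinaryOp op) {
      for (size_t o = 0; o < outer_dim; ++o) {
        for (size_t i = 0; i < inner_dim; ++i) {
          T acc = init;
          for (size_t m = 0; m < mid_dim; ++m) {
            size_t idx = (o * mid_dim + m) * inner_dim + i;
            acc = op(acc, x[idx]);
            y[idx] = acc;
          }
        }
      }
    }

    int main() {
      // Two distributions with four categories each.
      std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f,
                                  0.25f, 0.25f, 0.25f, 0.25f};
      std::vector<float> cumsum(probs.size());
      InclusiveScanCPU(probs.data(), cumsum.data(), /*outer_dim=*/2,
                       /*mid_dim=*/4, /*inner_dim=*/1, 0.0f,
                       [](float a, float b) { return a + b; });
      for (float v : cumsum) std::printf("%.2f ", v);
      std::printf("\n");  // 0.10 0.30 0.60 1.00 0.25 0.50 0.75 1.00
      return 0;
    }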
paddle/phi/kernels/gpu/multinomial_kernel.cu
...
...
@@ -23,11 +23,32 @@ limitations under the License. */
#include <thrust/scan.h>
#include <thrust/transform.h>

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"
#include "paddle/phi/kernels/funcs/multinomial_functor.h"
#include "paddle/phi/kernels/top_k_kernel.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/transform.h"

DECLARE_bool(use_curand);

namespace phi {
...
...
@@ -57,12 +78,12 @@ template <typename T>
 __global__ void GetCumulativeProbs(T* norm_probs_data,
                                    int64_t num_distributions,
                                    int64_t num_categories,
-                                   T* cumulative_probs) {
+                                   T* cumulative_probs_data) {
   int id = blockIdx.x;
   thrust::inclusive_scan(thrust::device,
                          norm_probs_data + id * num_categories,
                          norm_probs_data + (id + 1) * num_categories,
-                         cumulative_probs + id * num_categories);
+                         cumulative_probs_data + id * num_categories);
 }

 template <typename T>
...
...
@@ -80,7 +101,7 @@ struct RandomGeneratorCudaFunctor {
 };

 template <typename T>
-__device__ int binarySearchFunctor(T* cumulative_probs,
+__device__ int binarySearchFunctor(T* cumulative_probs_data,
                                    T* norm_probs_data,
                                    int num_categories,
                                    T rng_number) {
...
...
@@ -90,7 +111,7 @@ __device__ int binarySearchFunctor(T* cumulative_probs,
   while (right - left > 0) {
     int mid = left + (right - left) / 2;

-    T temp_prob = cumulative_probs[mid];
+    T temp_prob = cumulative_probs_data[mid];
     if (temp_prob < rng_number) {
       left = mid + 1;
     } else {
...
...
@@ -114,26 +135,35 @@ __global__ void sampleMultinomialWithReplacement(
     int64_t* out_data,
     const int64_t num_distributions,
     const int64_t num_categories,
-    T* cumulative_probs,
-    T* norm_probs_data) {
+    T* cumulative_probs_data,
+    T* norm_probs_data,
+    uint64_t seed,
+    uint64_t offset,
+    bool use_curand) {
   // use binary search to get the selected category sample id.
-  // let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
+  // let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id].
+  size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
+               threadIdx.x;

-  // for every distribution
-  int dist = blockIdx.y;
-  // for every sample
-  int sample = blockIdx.x * blockDim.x + threadIdx.x;
-  if (sample < num_samples) {
-    T rng_number = rng_data[sample + dist * num_samples];
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx, offset, &state);

-    // Find the bucket that a uniform random number lies in
-    int selected_category =
-        binarySearchFunctor<T>(cumulative_probs + dist * num_categories,
-                               norm_probs_data + dist * num_categories,
-                               num_categories,
-                               rng_number);
+  int sample = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
+    if (sample < num_samples) {
+      T rng_number = rng_data[sample + dist * num_samples];
+      if (use_curand) {
+        rng_number = static_cast<T>(curand_uniform4(&state).x);
+      }
+      // Find the bucket that a uniform random number lies in
+      int selected_category =
+          binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
+                                 norm_probs_data + dist * num_categories,
+                                 num_categories,
+                                 rng_number);

-    out_data[sample + dist * num_samples] = selected_category;
+      out_data[sample + dist * num_samples] = selected_category;
+    }
   }
 }
...
...
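The kernel above maps each uniform number to a category by binary search over the cumulative probabilities produced by the scan. A minimal CPU sketch of that per-sample step follows; it is not part of the commit and is simplified (it ignores the zero-probability handling done via norm_probs_data), and its names are illustrative.

    #include <cstdio>
    #include <vector>

    // Return the first index whose cumulative probability is >= rng_number.
    int SampleFromCumulative(const std::vector<float>& cumulative_probs,
                             float rng_number) {
      int left = 0;
      int right = static_cast<int>(cumulative_probs.size());
      while (right - left > 0) {
        int mid = left + (right - left) / 2;
        if (cumulative_probs[mid] < rng_number) {
          left = mid + 1;  // selected category lies to the right of mid
        } else {
          right = mid;     // mid may still be the selected category
        }
      }
      return left;
    }

    int main() {
      // Cumulative probabilities of categories {0.1, 0.2, 0.3, 0.4}.
      std::vector<float> cumulative = {0.1f, 0.3f, 0.6f, 1.0f};
      std::printf("%d\n", SampleFromCumulative(cumulative, 0.05f));  // 0
      std::printf("%d\n", SampleFromCumulative(cumulative, 0.45f));  // 2
      std::printf("%d\n", SampleFromCumulative(cumulative, 0.95f));  // 3
      return 0;
    }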
@@ -172,6 +202,54 @@ void MultinomialKernel(const Context& dev_ctx,
                            in_data_numel * sizeof(T),
                            cudaMemcpyDeviceToHost);
 #endif
+    if (FLAGS_use_curand) {
+      for (size_t i = 0; i < num_distributions; ++i) {
+        int zero_num = 0;
+        for (size_t j = 0; j < num_categories; ++j) {
+          T weight = cpu_in_data[i * num_distributions + j];
+          PADDLE_ENFORCE_GE(
+              weight,
+              0,
+              errors::InvalidArgument(
+                  "Each element of multinomial'input must >= 0, but got %f.",
+                  weight));
+          if (weight == static_cast<T>(0)) {
+            zero_num++;
+          }
+        }
+        int valid_samples = num_categories - zero_num;
+        PADDLE_ENFORCE_LE(
+            num_samples,
+            valid_samples,
+            errors::InvalidArgument("When replacement=False, 'num_samples' "
+                                    "must less than or eaqual to the number of "
+                                    "positive item of input"));
+      }
+
+      // Refer to [gumbel softmax algorithm]
+      DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
+      T* rand_data = rand.data<T>();
+      funcs::uniform_distribution<T> dist;
+      funcs::exponential_transform<T> trans(1.0);
+      funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
+
+      funcs::ForRange<Context> for_range(dev_ctx, x.numel());
+      for_range([rand_data, in_data] __device__(size_t idx) {
+        rand_data[idx] = in_data[idx] / rand_data[idx];
+      });
+
+      if (num_samples == 1) {
+        ArgMaxKernel<T, Context>(
+            dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
+      } else {
+        std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
+        DenseTensor value =
+            Empty<T, Context>(dev_ctx, ScalarArray(out_dim_vec));
+        TopkKernel<T, Context>(
+            dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out);
+      }
+      return;
+    }
+
     funcs::MultinomialFunctor<T>(dev_ctx,
                                  cpu_out_data,
...
...
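The FLAGS_use_curand branch above samples without replacement by dividing each weight by an independent Exponential(1) draw and taking ArgMax/TopK of the ratios; the commit's comment calls this the gumbel softmax algorithm, and it is the exponential-race form of the Gumbel-top-k trick. A hedged CPU sketch of the same idea follows, not part of the commit and with illustrative names.

    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <random>
    #include <vector>

    // Draw k distinct indices with probability proportional to weights:
    // key[i] = weights[i] / Exponential(1), then keep the k largest keys.
    std::vector<int> SampleWithoutReplacement(const std::vector<double>& weights,
                                              int k, std::mt19937& gen) {
      std::exponential_distribution<double> expo(1.0);
      std::vector<double> keys(weights.size());
      for (size_t i = 0; i < weights.size(); ++i) {
        keys[i] = weights[i] / expo(gen);
      }
      std::vector<int> idx(weights.size());
      std::iota(idx.begin(), idx.end(), 0);
      // Counterpart of the ArgMax (k == 1) / TopK (k > 1) call in the kernel.
      std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                        [&](int a, int b) { return keys[a] > keys[b]; });
      idx.resize(k);
      return idx;
    }

    int main() {
      std::mt19937 gen(100);
      std::vector<double> weights = {1.0, 2.0, 4.0, 8.0};
      for (int id : SampleWithoutReplacement(weights, 2, gen)) {
        std::printf("%d ", id);  // two distinct category ids
      }
      std::printf("\n");
      return 0;
    }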
@@ -228,7 +306,8 @@ void MultinomialKernel(const Context& dev_ctx,
   auto* norm_probs_data = dev_ctx.template Alloc<T>(&norm_probs_tensor);

   // number of threads in a block is min(num_categories, 512)
-  dim3 block_norm(num_categories < 512 ? num_categories : 512);
+  int block_size = num_categories < 512 ? num_categories : 512;
+  dim3 block_norm(block_size);
   dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
   NormalizeProbability<T><<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(
       norm_probs_data,
...
...
@@ -238,16 +317,34 @@ void MultinomialKernel(const Context& dev_ctx,
       num_categories);

   // Get cumulative probability of each distribution. It's the same function
-  // of
-  // ``cumsum`` op.
+  // of ``cumsum`` op.
   DenseTensor cumulative_probs_tensor;
   cumulative_probs_tensor.Resize({num_distributions, num_categories});
-  auto* cumulative_probs = dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
-
-  dim3 block_cumsum(1);
-  dim3 grid_cumsum(num_distributions);
-  GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
-      norm_probs_data, num_distributions, num_categories, cumulative_probs);
+  auto* cumulative_probs_data =
+      dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
+
+  if (FLAGS_use_curand) {
+    // 'phi::funcs::InclusiveScan' has higher accuracy than
+    // 'thrust::inclusive_scan'
+    funcs::InclusiveScan<T, std::plus<T>>(
+        /*in*/ norm_probs_data,
+        /*out*/ cumulative_probs_data,
+        /*outer_dim*/ static_cast<size_t>(num_distributions),
+        /*mid_dim*/ static_cast<size_t>(num_categories),
+        /*inner_dim*/ static_cast<size_t>(1),
+        /*init*/ static_cast<T>(0),
+        std::plus<T>(),
+        /*reverse=*/false,
+        dev_ctx);
+  } else {
+    dim3 block_cumsum(1);
+    dim3 grid_cumsum(num_distributions);
+    GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
+        norm_probs_data,
+        num_distributions,
+        num_categories,
+        cumulative_probs_data);
+  }

   // Generate random number for each sample.
   std::random_device rd;
...
...
@@ -266,16 +363,30 @@ void MultinomialKernel(const Context& dev_ctx,
       RandomGeneratorCudaFunctor<T>(seed));

   // Sample the multinomial distributions.
-  dim3 block_sample(128);
-  dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
-  sampleMultinomialWithReplacement<T>
-      <<<grid_sample, block_sample, 0, dev_ctx.stream()>>>(rng_data,
-                                                           num_samples,
-                                                           out_data,
-                                                           num_distributions,
-                                                           num_categories,
-                                                           cumulative_probs,
-                                                           norm_probs_data);
+  dim3 block(128);
+  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
+  const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
+  int grid_y = std::min<int64_t>(num_distributions, prop.maxGridSize[1]);
+  dim3 grid((num_samples - 1) / block.x + 1, grid_y);
+
+  auto gen_cuda = dev_ctx.GetGenerator();
+  size_t curand4_loop_times =
+      (num_distributions + 4 * grid_y - 1) / (4 * grid_y);
+  // 'increment' shoulde be multiple of 4
+  uint64_t increment = curand4_loop_times * 4;
+  auto seed_offset = gen_cuda->IncrementOffset(increment);
+
+  sampleMultinomialWithReplacement<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      rng_data,
+      num_samples,
+      out_data,
+      num_distributions,
+      num_categories,
+      cumulative_probs_data,
+      norm_probs_data,
+      seed_offset.first,
+      seed_offset.second,
+      FLAGS_use_curand);
 }

 }  // namespace phi
...
...
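For the launch-configuration hunk above, here is a small worked example of the grid and offset arithmetic; it is not part of the commit, and the maxGridSize value is an assumption for illustration. Each thread draws its randomness from a Philox state initialized with the seed and offset passed to the kernel, and the generator offset is advanced by a multiple of 4, matching the source comment, since curand_uniform4 yields four values per call.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t num_distributions = 1024;
      int64_t max_grid_size_y = 65535;  // assumed prop.maxGridSize[1] for this sketch
      int64_t grid_y = std::min<int64_t>(num_distributions, max_grid_size_y);

      // Same arithmetic as the kernel launch code above.
      uint64_t curand4_loop_times =
          (num_distributions + 4 * grid_y - 1) / (4 * grid_y);
      uint64_t increment = curand4_loop_times * 4;  // kept a multiple of 4

      std::printf("grid_y=%lld loop_times=%llu increment=%llu\n",
                  static_cast<long long>(grid_y),
                  static_cast<unsigned long long>(curand4_loop_times),
                  static_cast<unsigned long long>(increment));
      return 0;
    }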
python/paddle/fluid/tests/unittests/test_multinomial_op.py
...
...
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
 from paddle.fluid import core
 from op_test import OpTest
 import numpy as np
+import os


 def sample_output_one_dimension(out, dim):
...
...
@@ -216,5 +217,59 @@ class TestMultinomialError(unittest.TestCase):
         self.assertRaises(ValueError, test_dim_less_than_1)


+class TestRandomValue(unittest.TestCase):
+    def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        # Different GPU generatte different random value. Only test V100 here.
+        if not "V100" in paddle.device.cuda.get_device_name():
+            return
+
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
+        print("Test Fixed Random number on V100 GPU------>")
+        paddle.disable_static()
+        paddle.set_device('gpu')
+        paddle.seed(100)
+
+        x = paddle.randint(0, 100, [1024, 10000]).astype('float32')
+        y = paddle.multinomial(x, 1, replacement=False).numpy()
+        self.assertEqual(np.sum(y), 5187793)
+        self.assertEqual(np.mean(y), 5066.2041015625)
+        expect = [9982, 1655, 4741, 1323, 9319, 3298, 6473, 7477, 2507, 2628]
+        self.assertTrue(np.array_equal(y[100:110, :].flatten(), expect))
+
+        y = paddle.multinomial(x, 5000, replacement=False).numpy()
+        self.assertEqual(np.sum(y), 25603962316)
+        self.assertEqual(np.mean(y), 5000.77388984375)
+        expect = [7300, 6055, 8714, 5401, 7360, 161, 5035, 7002, 6788, 2916]
+        self.assertTrue(np.array_equal(y[100, 1000:1010], expect))
+
+        y = paddle.multinomial(x, 5000, replacement=False).numpy()
+        self.assertEqual(np.sum(y), 25592855710)
+        self.assertEqual(np.mean(y), 4998.604630859375)
+        expect = [5700, 6567, 4399, 5688, 7472, 545, 6894, 526, 2124, 385]
+        self.assertTrue(np.array_equal(y[300, 3000:3010], expect))
+
+        y = paddle.multinomial(x, 20000, replacement=True).numpy()
+        self.assertEqual(np.sum(y), 102371362581)
+        self.assertEqual(np.mean(y), 4998.60168852539)
+        self.assertEqual(np.std(y), 2886.316308500771)
+        expect = [7630, 8235, 8445, 3275, 5580, 4591, 1331, 342, 1662, 7156]
+        self.assertTrue(np.array_equal(y[100, 0:10], expect))
+
+        y = paddle.multinomial(x, 20000, replacement=True).numpy()
+        self.assertEqual(np.sum(y), 102400672117)
+        self.assertEqual(np.mean(y), 5000.032818212891)
+        self.assertEqual(np.std(y), 2886.913426124017)
+        expect = [4159, 7849, 9305, 5759, 4422, 122, 345, 2897, 5200, 5911]
+        self.assertTrue(np.array_equal(y[100, 0:10], expect))
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()