Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
80537a1d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
80537a1d
编写于
9月 28, 2020
作者:
P
pangyoki
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multinomial python api unittest
上级
c66eec75
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
107 addition
and
160 deletion
+107
-160
paddle/fluid/operators/multinomial_op.cu
paddle/fluid/operators/multinomial_op.cu
+23
-143
python/paddle/fluid/tests/unittests/test_multinomial_op.py
python/paddle/fluid/tests/unittests/test_multinomial_op.py
+43
-0
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+41
-17
未找到文件。
paddle/fluid/operators/multinomial_op.cu
浏览文件 @
80537a1d
...
@@ -26,69 +26,17 @@ limitations under the License. */
...
@@ -26,69 +26,17 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
/*
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
*/
/*
template <class T>
__global__ void SumArrayCUDAKernel(T **in, T *out, size_t in_size) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
// T total(read_dst ? out[id] : static_cast<T>(0));
T total(static_cast<T>(0))
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}*/
/*
template <typename T>
__global__ void NormalizeProbability(T* probs, int64_t rows, int64_t cols) {
extern __shared__ std::vector<T> sum_rows(rows);
T val;
for (int64_t i = blockId.x; i < rows; i += gridDim.x) {
T sum = static_cast<T>(0);
for (int64_t j = threadIdx.x; j < cols; j += blockDim.x) {
val = probs[i * cols + j];
sum += val;
}
}
}*/
template
<
typename
T
>
template
<
typename
T
>
__global__
void
NormalizeProbability
(
T
*
norm_probs
,
const
T
*
in_data
,
__global__
void
NormalizeProbability
(
T
*
norm_probs
,
const
T
*
in_data
,
T
*
sum_rows
)
{
T
*
sum_rows
)
{
// int id = blockIdx.x * blockDim.x + threadIdx.x;
// int id = threadIdx.x;
int
id
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
int
id
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
blockIdx
.
y
];
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
blockIdx
.
y
];
}
}
template
<
typename
T
>
__global__
void
yokiFunc
(
const
T
*
in_data
,
T
*
out
)
{
// int id = blockIdx.x * blockDim.x + threadIdx.x;
// int id = threadIdx.x;
int
id
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
out
[
id
]
=
in_data
[
id
];
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
Cumsum
(
T
*
norm_probs_data
,
int64_t
num_distributions
,
__global__
void
Cumsum
(
T
*
norm_probs_data
,
int64_t
num_distributions
,
int64_t
num_categories
,
T
*
cumulative_probs
)
{
int64_t
num_categories
,
T
*
cumulative_probs
)
{
// int id = blockIdx.x;
for
(
int
id
=
blockIdx
.
x
;
id
<
num_distributions
;
id
+=
gridDim
.
x
)
{
for
(
int
id
=
blockIdx
.
x
;
id
<
num_distributions
;
id
+=
gridDim
.
x
)
{
thrust
::
inclusive_scan
(
thrust
::
device
,
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
id
*
num_categories
,
...
@@ -111,52 +59,43 @@ struct RandomGeneratorCudaFunctor {
...
@@ -111,52 +59,43 @@ struct RandomGeneratorCudaFunctor {
}
}
};
};
/*
template
<
typename
T
>
template
<
typename
T
>
class MultinomialCudaFunctor(T* out_data, const T* in_data,
__device__
int
binarySearchFunctor
(
T
*
cumdist
,
T
*
dist
,
int
size
,
T
val
)
{
const int64_t num_samples, const bool replacement,
int
left
=
0
;
const int64_t num_categories,
int
right
=
size
;
const int64_t num_distributions) {
}*/
template
<
typename
T
>
__device__
int
binarySearchForMultinomial
(
T
*
cumdist
,
T
*
dist
,
int
size
,
T
val
)
{
int
start
=
0
;
int
end
=
size
;
// cumdist[size - 1] = 0 => all zero prob dist
// cumdist[size - 1] = 0 => all zero prob dist
// CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));
// CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));
while
(
end
-
star
t
>
0
)
{
while
(
right
-
lef
t
>
0
)
{
int
mid
=
start
+
(
end
-
star
t
)
/
2
;
int
mid
=
left
+
(
right
-
lef
t
)
/
2
;
T
midVal
=
cumdist
[
mid
];
T
midVal
=
cumdist
[
mid
];
if
(
midVal
<
val
)
{
if
(
midVal
<
val
)
{
star
t
=
mid
+
1
;
lef
t
=
mid
+
1
;
}
else
{
}
else
{
end
=
mid
;
right
=
mid
;
}
}
}
}
if
(
star
t
==
size
)
{
if
(
lef
t
==
size
)
{
// No probability mass or precision problems; just return the
// No probability mass or precision problems; just return the
// first non-zero element by setting
star
t to size-1 here,
// first non-zero element by setting
lef
t to size-1 here,
// the code below will move it to the last non-zero probability
// the code below will move it to the last non-zero probability
// this actually can happen when the random number is 1
// this actually can happen when the random number is 1
// (github pytorch issue #4858).
// (github pytorch issue #4858).
star
t
=
size
-
1
;
lef
t
=
size
-
1
;
}
}
while
(
start
>=
1
&&
dist
[
start
]
==
0
)
star
t
--
;
while
(
left
>=
1
&&
dist
[
left
]
==
0
)
lef
t
--
;
return
star
t
;
return
lef
t
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
sampleMultinomialWithReplacement
(
__global__
void
sampleMultinomialWithReplacement
(
T
*
rng
,
const
int64_t
totalSamples
,
T
*
dest
,
const
int64_t
distributions
,
T
*
rng_data
,
const
int64_t
num_samples
,
T
*
out_data
,
const
int64_t
categories
,
T
*
normDistPrefixSum
,
T
*
normDist
)
{
const
int64_t
num_distributions
,
const
int64_t
num_categories
,
T
*
cumulative_probs
,
T
*
norm_probs_data
)
{
// At the moment, each warp computes one sample value in the binary
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on.
// values and limit divergence though later on.
...
@@ -170,22 +109,23 @@ __global__ void sampleMultinomialWithReplacement(
...
@@ -170,22 +109,23 @@ __global__ void sampleMultinomialWithReplacement(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
for
(
int
curDist
=
blockIdx
.
y
;
curDist
<
distributions
;
for
(
int
curDist
=
blockIdx
.
y
;
curDist
<
num_
distributions
;
curDist
+=
gridDim
.
y
)
{
curDist
+=
gridDim
.
y
)
{
for
(
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
sample
<
totalS
amples
;
sample
+=
blockDim
.
x
*
gridDim
.
x
)
{
sample
<
num_s
amples
;
sample
+=
blockDim
.
x
*
gridDim
.
x
)
{
// we are losing 3 out of 4 generated numbers but it's ok
// we are losing 3 out of 4 generated numbers but it's ok
// this kernel is not very efficient anyway
// this kernel is not very efficient anyway
// T uniform_random = dist(rng);
// T uniform_random = dist(rng);
T
uniform_random
=
rng
[
sample
+
curDist
*
totalS
amples
];
T
uniform_random
=
rng
_data
[
sample
+
curDist
*
num_s
amples
];
// Find the bucket that a uniform sample lies in
// Find the bucket that a uniform sample lies in
int
choice
=
binarySearchForMultinomial
<
T
>
(
int
choice
=
normDistPrefixSum
+
curDist
*
categories
,
binarySearchFunctor
<
T
>
(
cumulative_probs
+
curDist
*
num_categories
,
normDist
+
curDist
*
categories
,
categories
,
uniform_random
);
norm_probs_data
+
curDist
*
num_categories
,
num_categories
,
uniform_random
);
dest
[
sample
+
curDist
*
totalS
amples
]
=
choice
;
out_data
[
sample
+
curDist
*
num_s
amples
]
=
choice
;
}
}
}
}
}
}
...
@@ -198,14 +138,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -198,14 +138,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
const
auto
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
auto
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
// auto yokiout = ctx.Output<framework::Tensor>("yokiOut");
const
int64_t
num_samples
=
ctx
.
Attr
<
int
>
(
"num_samples"
);
const
int64_t
num_samples
=
ctx
.
Attr
<
int
>
(
"num_samples"
);
const
bool
replacement
=
ctx
.
Attr
<
bool
>
(
"replacement"
);
const
bool
replacement
=
ctx
.
Attr
<
bool
>
(
"replacement"
);
auto
*
in_data
=
x
->
data
<
T
>
();
auto
*
in_data
=
x
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// auto* yokiout_data = yokiout->mutable_data<T>(ctx.GetPlace());
auto
in_dims
=
x
->
dims
();
auto
in_dims
=
x
->
dims
();
int64_t
in_rank
=
in_dims
.
size
();
int64_t
in_rank
=
in_dims
.
size
();
...
@@ -215,10 +152,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -215,10 +152,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
if
(
!
replacement
)
{
if
(
!
replacement
)
{
int
in_data_numel
=
x
->
numel
();
int
in_data_numel
=
x
->
numel
();
int
out_data_numel
=
out
->
numel
();
int
out_data_numel
=
out
->
numel
();
// std::vector<T> cpu_in_data(in_data_numel);
// std::vector<T> cpu_out_data(out_data_numel);
// T cpu_in_data[in_data_numel];
// T cpu_out_data[out_data_numel];
T
*
cpu_in_data
=
new
T
[
in_data_numel
];
T
*
cpu_in_data
=
new
T
[
in_data_numel
];
T
*
cpu_out_data
=
new
T
[
out_data_numel
];
T
*
cpu_out_data
=
new
T
[
out_data_numel
];
...
@@ -226,10 +159,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -226,10 +159,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
cudaMemcpy
(
cpu_in_data
,
in_data
,
in_data_numel
*
sizeof
(
T
),
cudaMemcpy
(
cpu_in_data
,
in_data
,
in_data_numel
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
);
cudaMemcpyDeviceToHost
);
VLOG
(
3
)
<<
"Print cpu_in_data "
<<
cpu_in_data
[
0
]
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print in_data_numel "
<<
in_data_numel
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print out_data_numel "
<<
out_data_numel
<<
"
\n
"
;
MultinomialFunctor
<
T
>
(
cpu_out_data
,
cpu_in_data
,
num_samples
,
replacement
,
MultinomialFunctor
<
T
>
(
cpu_out_data
,
cpu_in_data
,
num_samples
,
replacement
,
num_categories
,
num_distributions
);
num_categories
,
num_distributions
);
cudaMemcpy
(
out_data
,
cpu_out_data
,
out_data_numel
*
sizeof
(
T
),
cudaMemcpy
(
out_data
,
cpu_out_data
,
out_data_numel
*
sizeof
(
T
),
...
@@ -240,21 +169,9 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -240,21 +169,9 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
return
;
return
;
}
}
// std::vector<T> sum_rows(num_distributions);
// SumArrayCUDAKernel<T>(in_data, sum_rows,)
VLOG
(
3
)
<<
"Print num_distributions "
<<
num_distributions
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print num_categories "
<<
num_categories
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print in_rank "
<<
in_rank
<<
"
\n
"
;
framework
::
Tensor
sum_rows_t
;
framework
::
Tensor
sum_rows_t
;
auto
*
sum_rows_data
=
auto
*
sum_rows_data
=
sum_rows_t
.
mutable_data
<
T
>
({
num_distributions
},
ctx
.
GetPlace
());
sum_rows_t
.
mutable_data
<
T
>
({
num_distributions
},
ctx
.
GetPlace
());
// auto* sum_rows_data =
// sum_rows_t->mutable_data<T>(framework::make_ddim({num_distributions}),
// ctx.GetPlace());
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>()
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
eigen_device
();
.
eigen_device
();
...
@@ -262,58 +179,34 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -262,58 +179,34 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
if
(
num_distributions
==
1
)
{
if
(
num_distributions
==
1
)
{
auto
eigen_input
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_input
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_sum_rows
=
framework
::
EigenVector
<
T
>::
From
(
sum_rows_t
);
auto
eigen_sum_rows
=
framework
::
EigenVector
<
T
>::
From
(
sum_rows_t
);
// auto eigen_sum_rows = framework::EigenScalar<T>::From(sum_rows_t);
eigen_sum_rows
.
device
(
place
)
=
eigen_sum_rows
.
device
(
place
)
=
eigen_input
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
eigen_input
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
.
eval
()
.
eval
()
.
reshape
(
Eigen
::
DSizes
<
int
,
1
>
(
sum_rows_t
.
dims
()[
0
]));
.
reshape
(
Eigen
::
DSizes
<
int
,
1
>
(
sum_rows_t
.
dims
()[
0
]));
}
else
{
}
else
{
auto
eigen_input
=
framework
::
EigenMatrix
<
T
>::
From
(
*
x
);
auto
eigen_input
=
framework
::
EigenMatrix
<
T
>::
From
(
*
x
);
// auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
auto
eigen_sum_rows
=
framework
::
EigenVector
<
T
>::
From
(
sum_rows_t
);
auto
eigen_sum_rows
=
framework
::
EigenVector
<
T
>::
From
(
sum_rows_t
);
eigen_sum_rows
.
device
(
place
)
=
eigen_input
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
));
eigen_sum_rows
.
device
(
place
)
=
eigen_input
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
));
// .eval()
// .reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0]));
// eigen_sum_rows.device(place) =
// eigen_input.sum().eval().reshape(Eigen::DSizes<int, 1>(1));
}
}
// std::vector<T> in_data_norm(num_categories);
framework
::
Tensor
norm_probs_t
;
framework
::
Tensor
norm_probs_t
;
auto
*
norm_probs_data
=
norm_probs_t
.
mutable_data
<
T
>
(
auto
*
norm_probs_data
=
norm_probs_t
.
mutable_data
<
T
>
(
{
num_distributions
,
num_categories
},
ctx
.
GetPlace
());
{
num_distributions
,
num_categories
},
ctx
.
GetPlace
());
// dim3 grid(num_distributions);
// dim3 block(num_categories);
dim3
block
(
num_categories
<
512
?
num_categories
:
512
);
dim3
block
(
num_categories
<
512
?
num_categories
:
512
);
dim3
grid
((
num_categories
-
1
)
/
block
.
x
+
1
,
num_distributions
);
dim3
grid
((
num_categories
-
1
)
/
block
.
x
+
1
,
num_distributions
);
NormalizeProbability
<
NormalizeProbability
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
norm_probs_data
,
in_data
,
sum_rows_data
);
norm_probs_data
,
in_data
,
sum_rows_data
);
// num_distributions can only be 1.
// std::vector<T> cumulative_probs(num_categories);
framework
::
Tensor
cumulative_probs_t
;
framework
::
Tensor
cumulative_probs_t
;
auto
*
cumulative_probs
=
cumulative_probs_t
.
mutable_data
<
T
>
(
auto
*
cumulative_probs
=
cumulative_probs_t
.
mutable_data
<
T
>
(
{
num_distributions
,
num_categories
},
ctx
.
GetPlace
());
{
num_distributions
,
num_categories
},
ctx
.
GetPlace
());
// T cumulative_probs[num_categories];
dim3
block1
(
1
);
dim3
block1
(
1
);
dim3
grid1
(
num_distributions
);
dim3
grid1
(
num_distributions
);
Cumsum
<
T
><<<
grid1
,
block1
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
Cumsum
<
T
><<<
grid1
,
block1
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
norm_probs_data
,
num_distributions
,
num_categories
,
cumulative_probs
);
norm_probs_data
,
num_distributions
,
num_categories
,
cumulative_probs
);
/*
dim3 block2(num_categories < 512 ? num_categories : 512);
dim3 grid2((num_categories-1)/block2.x+1, num_distributions);
yokiFunc<T><<<grid2, block2, 0, ctx.cuda_device_context().stream()>>>(
cumulative_probs, yokiout_data);*/
// int64_t size = num_categories;
// thrust::inclusive_scan(thrust::device, norm_probs_data,
// norm_probs_data + num_categories,
// cumulative_probs);
VLOG
(
3
)
<<
"Print cumsum "
<<
cumulative_probs
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print cumsum "
<<
cumulative_probs
<<
"
\n
"
;
if
(
replacement
)
{
if
(
replacement
)
{
...
@@ -336,24 +229,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
...
@@ -336,24 +229,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
index_sequence_begin
+
num_distributions
*
num_samples
,
rng_data
,
index_sequence_begin
+
num_distributions
*
num_samples
,
rng_data
,
RandomGeneratorCudaFunctor
<
T
>
(
seed
));
RandomGeneratorCudaFunctor
<
T
>
(
seed
));
VLOG
(
3
)
<<
"Print enter
\n
"
;
// VLOG(3) << "Print size in_data " <<
// sizeof(in_data)/sizeof(in_data[num_categories-1]) << "\n";
// VLOG(3) << "Print norm_probs_data0 " <<
// sizeof(norm_probs_data[num_categories-1]) << "\n";
sampleMultinomialWithReplacement
<
sampleMultinomialWithReplacement
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
rng_data
,
num_samples
,
out_data
,
num_distributions
,
num_categories
,
rng_data
,
num_samples
,
out_data
,
num_distributions
,
num_categories
,
cumulative_probs
,
norm_probs_data
);
cumulative_probs
,
norm_probs_data
);
VLOG
(
3
)
<<
"Print end
\n
"
<<
out_data
;
}
}
VLOG
(
3
)
<<
"Print final end
\n
"
;
// MultinomialCudaFunctor<T>(out_data, in_data, num_samples, replacement,
// num_categories, num_distributions);
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_multinomial_op.py
浏览文件 @
80537a1d
...
@@ -126,6 +126,49 @@ class TestMultinomialApi(unittest.TestCase):
...
@@ -126,6 +126,49 @@ class TestMultinomialApi(unittest.TestCase):
sample_prob
,
prob
,
rtol
=
0
,
atol
=
0.01
),
sample_prob
,
prob
,
rtol
=
0
,
atol
=
0.01
),
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
def
test_dygraph2
(
self
):
paddle
.
disable_static
()
x
=
paddle
.
rand
([
3
,
4
])
out
=
paddle
.
multinomial
(
x
,
num_samples
=
100000
,
replacement
=
True
)
x_numpy
=
x
.
numpy
()
out_list
=
np
.
split
(
out
.
numpy
(),
3
,
axis
=
0
)
count_array
=
[
0
]
*
3
for
i
in
range
(
3
):
count_array
[
i
]
=
np
.
unique
(
out_list
[
i
],
return_counts
=
True
)[
1
].
astype
(
"float32"
)
sample_prob
=
np
.
stack
(
count_array
,
axis
=
0
)
sample_prob
/=
sample_prob
.
sum
(
axis
=-
1
,
keepdims
=
True
)
prob
=
x_numpy
/
x_numpy
.
sum
(
axis
=-
1
,
keepdims
=
True
)
self
.
assertTrue
(
np
.
allclose
(
sample_prob
,
prob
,
rtol
=
0
,
atol
=
0.01
),
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
paddle
.
enable_static
()
def
test_dygraph3
(
self
):
paddle
.
disable_static
()
x
=
paddle
.
rand
([
1000
])
out
=
paddle
.
multinomial
(
x
,
num_samples
=
100
,
replacement
=
False
)
x_numpy
=
x
.
numpy
()
unique_out
=
np
.
unique
(
out
.
numpy
())
self
.
assertEqual
(
len
(
unique_out
),
100
,
"replacement is False. categories can't be sampled repeatedly"
)
paddle
.
enable_static
()
"""
def test_replacement_error(self):
def test_error():
paddle.disable_static()
x = paddle.rand([5])
out = paddle.multinomial(x, num_samples=10, replacement=False)
self.assertRaises(OutOfRangeError, test_error) # not OutOfRangeError
"""
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/tensor/random.py
浏览文件 @
80537a1d
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# TODO: define random functions
# TODO: define random functions
from
..fluid
import
core
from
..fluid
import
core
from
..fluid.framework
import
in_dygraph_mode
,
Variable
,
convert_np_dtype_to_dtype_
from
..fluid.framework
import
in_dygraph_mode
,
Variable
,
convert_np_dtype_to_dtype_
...
@@ -40,18 +40,18 @@ def bernoulli(x, name=None):
...
@@ -40,18 +40,18 @@ def bernoulli(x, name=None):
This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
The input ``x`` is a tensor with probabilities for generating the random binary number.
The input ``x`` is a tensor with probabilities for generating the random binary number.
Each element in ``x`` should be in [0, 1], and the out is generated by:
Each element in ``x`` should be in [0, 1], and the out is generated by:
.. math::
.. math::
out_i ~ Bernoulli (x_i)
out_i ~ Bernoulli (x_i)
Args:
Args:
x(Tensor): A tensor with probabilities for generating the random binary number. The data type
x(Tensor): A tensor with probabilities for generating the random binary number. The data type
should be float32, float64.
should be float32, float64.
name(str, optional): The default value is None. Normally there is no
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``.
Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``.
Examples:
Examples:
...
@@ -80,7 +80,7 @@ def bernoulli(x, name=None):
...
@@ -80,7 +80,7 @@ def bernoulli(x, name=None):
helper
=
LayerHelper
(
"randint"
,
**
locals
())
helper
=
LayerHelper
(
"randint"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
# maybe set out to int32 ?
dtype
=
x
.
dtype
)
# maybe set out to int32 ?
helper
.
append_op
(
helper
.
append_op
(
type
=
'bernoulli'
,
inputs
=
{
"X"
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{})
type
=
'bernoulli'
,
inputs
=
{
"X"
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{})
return
out
return
out
...
@@ -88,8 +88,23 @@ def bernoulli(x, name=None):
...
@@ -88,8 +88,23 @@ def bernoulli(x, name=None):
def
multinomial
(
x
,
num_samples
=
1
,
replacement
=
False
,
name
=
None
):
def
multinomial
(
x
,
num_samples
=
1
,
replacement
=
False
,
name
=
None
):
"""
"""
This OP returns a Tensor filled with random values sampled from a Multinomical
distribution. The input ``x`` is a tensor with probabilities for generating the
random number. Each element in ``x`` should be larger or equal to 0, but not all
0. ``replacement`` indicates whether it is a replaceable sample. If ``replacement``
is True, a category can be sampled more than once.
Args:
x(Tensor): A tensor with probabilities for generating the random number. The data type
should be float32, float64.
num_samples(int, optional): Number of samples, default is 1.
replacement(bool, optional): whether it is a replaceable sample, default is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with sampled category index after ``num_samples`` times samples.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -97,15 +112,24 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
...
@@ -97,15 +112,24 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
paddle.disable_static()
paddle.disable_static()
x = paddle.rand([2,
3
])
x = paddle.rand([2,
4
])
print(x.numpy())
print(x.numpy())
# [[0.
11272584 0.3890902 0.7730957
]
# [[0.
7713825 0.4055941 0.433339 0.70706886
]
# [0.
10351662 0.8510418 0.63806665
]]
# [0.
9223313 0.8519825 0.04574518 0.16560672
]]
out = paddle.bernoulli(x)
out1 = paddle.multinomial(x, num_samples=5, replacement=True)
print(out.numpy())
print(out1.numpy())
# [[0. 0. 1.]
# [[3. 3. 1. 1. 0.]
# [0. 0. 1.]]
# [0. 0. 0. 0. 1.]]
out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories
out3 = paddle.multinomial(x, num_samples=3)
print(out3.numpy())
# [[0. 2. 3.]
# [0. 1. 3.]]
"""
"""
...
@@ -152,7 +176,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
...
@@ -152,7 +176,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
Returns:
Returns:
Tensor: A Tensor filled with random values sampled from a Gaussian
Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
distribution, with ``shape`` and ``dtype``.
"""
"""
op_type_for_check
=
'gaussian/standard_normal/randn/normal'
op_type_for_check
=
'gaussian/standard_normal/randn/normal'
seed
=
0
seed
=
0
...
@@ -393,7 +417,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
...
@@ -393,7 +417,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
import paddle
import paddle
paddle.disable_static()
paddle.disable_static()
...
@@ -481,7 +505,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
...
@@ -481,7 +505,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
need for user to set this property. For more information, please
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Tensor: A Tensor filled with random integers from a discrete uniform
Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
...
@@ -591,7 +615,7 @@ def randperm(n, dtype="int64", name=None):
...
@@ -591,7 +615,7 @@ def randperm(n, dtype="int64", name=None):
out2 = paddle.randperm(7, 'int32')
out2 = paddle.randperm(7, 'int32')
# [1, 6, 2, 0, 4, 3, 5] # random
# [1, 6, 2, 0, 4, 3, 5] # random
"""
"""
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录