Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
dab6fa97
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dab6fa97
编写于
9月 16, 2020
作者:
P
pangyoki
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cuda kernrl with num_distribution is 1, and not support replacement=False
上级
8dd56af4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
298 addition
and
5 deletion
+298
-5
paddle/fluid/operators/multinomial_op.cc
paddle/fluid/operators/multinomial_op.cc
+2
-2
paddle/fluid/operators/multinomial_op.cu
paddle/fluid/operators/multinomial_op.cu
+285
-0
paddle/fluid/operators/multinomial_op.h
paddle/fluid/operators/multinomial_op.h
+0
-1
python/paddle/fluid/tests/unittests/test_multinomial_op.py
python/paddle/fluid/tests/unittests/test_multinomial_op.py
+11
-2
未找到文件。
paddle/fluid/operators/multinomial_op.cc
浏览文件 @
dab6fa97
...
@@ -83,7 +83,7 @@ class MultinomialOpKernel<platform::CPUDeviceContext, T>
...
@@ -83,7 +83,7 @@ class MultinomialOpKernel<platform::CPUDeviceContext, T>
const
int64_t
num_categories
=
in_dims
[
in_rank
-
1
];
const
int64_t
num_categories
=
in_dims
[
in_rank
-
1
];
const
int64_t
num_distributions
=
in_rank
>
1
?
in_dims
[
in_rank
-
2
]
:
1
;
const
int64_t
num_distributions
=
in_rank
>
1
?
in_dims
[
in_rank
-
2
]
:
1
;
MultinomialFunctor
(
out_data
,
in_data
,
num_samples
,
replacement
,
MultinomialFunctor
<
T
>
(
out_data
,
in_data
,
num_samples
,
replacement
,
num_categories
,
num_distributions
);
num_categories
,
num_distributions
);
}
}
};
};
...
...
paddle/fluid/operators/multinomial_op.cu
0 → 100644
浏览文件 @
dab6fa97
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
/*
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
*/
/*
template <class T>
__global__ void SumArrayCUDAKernel(T **in, T *out, size_t in_size) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
// T total(read_dst ? out[id] : static_cast<T>(0));
T total(static_cast<T>(0))
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}*/
/*
template <typename T>
__global__ void NormalizeProbability(T* probs, int64_t rows, int64_t cols) {
extern __shared__ std::vector<T> sum_rows(rows);
T val;
for (int64_t i = blockId.x; i < rows; i += gridDim.x) {
T sum = static_cast<T>(0);
for (int64_t j = threadIdx.x; j < cols; j += blockDim.x) {
val = probs[i * cols + j];
sum += val;
}
}
}*/
template
<
typename
T
>
__global__
void
NormalizeProbability
(
T
*
norm_probs
,
const
T
*
in_data
,
T
*
sum_rows
)
{
// int id = blockIdx.x * blockDim.x + threadIdx.x;
int
id
=
threadIdx
.
x
;
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
0
];
}
template
<
typename
T
>
struct
RandomGeneratorCudaFunctor
{
unsigned
int
seed_
;
__host__
__device__
RandomGeneratorCudaFunctor
(
int
seed
)
:
seed_
(
seed
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
0.0
,
1.0
);
rng
.
discard
(
n
);
return
dist
(
rng
);
}
};
/*
template <typename T>
class MultinomialCudaFunctor(T* out_data, const T* in_data,
const int64_t num_samples, const bool replacement,
const int64_t num_categories,
const int64_t num_distributions) {
}*/
template
<
typename
T
>
__device__
int
binarySearchForMultinomial
(
T
*
cumdist
,
T
*
dist
,
int
size
,
T
val
)
{
int
start
=
0
;
int
end
=
size
;
// cumdist[size - 1] = 0 => all zero prob dist
// CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));
while
(
end
-
start
>
0
)
{
int
mid
=
start
+
(
end
-
start
)
/
2
;
T
midVal
=
cumdist
[
mid
];
if
(
midVal
<
val
)
{
start
=
mid
+
1
;
}
else
{
end
=
mid
;
}
}
if
(
start
==
size
)
{
// No probability mass or precision problems; just return the
// first non-zero element by setting start to size-1 here,
// the code below will move it to the last non-zero probability
// this actually can happen when the random number is 1
// (github pytorch issue #4858).
start
=
size
-
1
;
}
while
(
start
>=
1
&&
dist
[
start
]
==
0
)
start
--
;
return
start
;
}
template
<
typename
T
>
__global__
void
sampleMultinomialWithReplacement
(
T
*
rng
,
const
int64_t
totalSamples
,
T
*
dest
,
const
int64_t
distributions
,
const
int64_t
categories
,
T
*
normDistPrefixSum
,
T
*
normDist
)
{
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on.
// global index formula for 2D grid of 1D blocks
// int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x +
// threadIdx.x;
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
sample
<
totalSamples
;
sample
+=
blockDim
.
x
*
gridDim
.
x
)
{
// we are losing 3 out of 4 generated numbers but it's ok
// this kernel is not very efficient anyway
// T uniform_random = dist(rng);
T
uniform_random
=
rng
[
sample
];
// Find the bucket that a uniform sample lies in
int
choice
=
binarySearchForMultinomial
<
T
>
(
normDistPrefixSum
,
normDist
,
categories
,
uniform_random
);
dest
[
sample
]
=
choice
;
}
}
template
<
typename
T
>
class
MultinomialOpKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
int64_t
num_samples
=
ctx
.
Attr
<
int
>
(
"num_samples"
);
const
bool
replacement
=
ctx
.
Attr
<
bool
>
(
"replacement"
);
auto
*
in_data
=
x
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
in_dims
=
x
->
dims
();
int64_t
in_rank
=
in_dims
.
size
();
const
int64_t
num_categories
=
in_dims
[
in_rank
-
1
];
const
int64_t
num_distributions
=
in_rank
>
1
?
in_dims
[
in_rank
-
2
]
:
1
;
// std::vector<T> sum_rows(num_distributions);
// SumArrayCUDAKernel<T>(in_data, sum_rows,)
VLOG
(
3
)
<<
"Print num_distributions "
<<
num_distributions
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print num_categories "
<<
num_categories
<<
"
\n
"
;
VLOG
(
3
)
<<
"Print in_rank "
<<
in_rank
<<
"
\n
"
;
framework
::
Tensor
sum_rows_t
;
auto
*
sum_rows_data
=
sum_rows_t
.
mutable_data
<
T
>
({
1
},
ctx
.
GetPlace
());
// auto* sum_rows_data =
// sum_rows_t->mutable_data<T>(framework::make_ddim({1}), ctx.GetPlace());
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
eigen_device
();
auto
eigen_input
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
// auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
auto
eigen_sum_rows
=
framework
::
EigenScalar
<
T
>::
From
(
sum_rows_t
);
eigen_sum_rows
.
device
(
place
)
=
eigen_input
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
0
))
.
eval
()
.
reshape
(
Eigen
::
DSizes
<
int
,
1
>
(
sum_rows_t
.
dims
()[
0
]));
// eigen_sum_rows.device(place) =
// eigen_input.sum().eval().reshape(Eigen::DSizes<int, 1>(1));
dim3
grid
(
num_distributions
);
dim3
block
(
num_categories
);
// std::vector<T> in_data_norm(num_categories);
framework
::
Tensor
norm_probs_t
;
auto
*
norm_probs_data
=
norm_probs_t
.
mutable_data
<
T
>
({
num_categories
},
ctx
.
GetPlace
());
NormalizeProbability
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
norm_probs_data
,
in_data
,
sum_rows_data
);
// num_distributions can only be 1.
// std::vector<T> cumulative_probs(num_categories);
framework
::
Tensor
cumulative_probs_t
;
auto
*
cumulative_probs
=
cumulative_probs_t
.
mutable_data
<
T
>
({
num_categories
},
ctx
.
GetPlace
());
// T cumulative_probs[num_categories];
int64_t
size
=
num_categories
;
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
,
norm_probs_data
+
num_categories
,
cumulative_probs
);
if
(
replacement
)
{
dim3
block
(
128
);
// int grid_y = 1;
dim3
grid
((
num_samples
-
1
)
/
block
.
x
+
1
);
/*
// std::vector<T> rng(num_samples);
T rng[num_samples];
std::uniform_real_distribution<T> dist(0, 1);
auto gen_ptr = framework::DefaultCPUGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int s = 0; s < num_samples; s++) {
rng[s] = dist(*engine);
}
*/
std
::
random_device
rd
;
auto
seed
=
rd
();
framework
::
Tensor
rng_data_t
;
auto
*
rng_data
=
rng_data_t
.
mutable_data
<
T
>
({
num_samples
},
ctx
.
GetPlace
());
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
platform
::
Transform
<
platform
::
CUDADeviceContext
>
trans
;
auto
*
context
=
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
&
ctx
.
device_context
());
trans
(
*
context
,
index_sequence_begin
,
index_sequence_begin
+
num_samples
,
rng_data
,
RandomGeneratorCudaFunctor
<
T
>
(
seed
));
VLOG
(
3
)
<<
"Print enter
\n
"
;
// VLOG(3) << "Print size in_data " <<
// sizeof(in_data)/sizeof(in_data[num_categories-1]) << "\n";
// VLOG(3) << "Print norm_probs_data0 " <<
// sizeof(norm_probs_data[num_categories-1]) << "\n";
sampleMultinomialWithReplacement
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
rng_data
,
num_samples
,
out_data
,
num_distributions
,
num_categories
,
cumulative_probs
,
norm_probs_data
);
}
// MultinomialCudaFunctor<T>(out_data, in_data, num_samples, replacement,
// num_categories, num_distributions);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
multinomial
,
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
MultinomialOpKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/multinomial_op.h
浏览文件 @
dab6fa97
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <vector>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
python/paddle/fluid/tests/unittests/test_multinomial_op.py
浏览文件 @
dab6fa97
...
@@ -26,6 +26,14 @@ class TestMultinomialOp(OpTest):
...
@@ -26,6 +26,14 @@ class TestMultinomialOp(OpTest):
self
.
init_data
()
self
.
init_data
()
self
.
inputs
=
{
"X"
:
self
.
input_np
}
self
.
inputs
=
{
"X"
:
self
.
input_np
}
"""
def init_data(self):
# input probability is a vector, and replacement is True
self.input_np = np.random.rand(4)
self.outputs = {"Out": np.zeros(100000).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}
"""
def
init_data
(
self
):
def
init_data
(
self
):
# input probability is a vector, and replacement is True
# input probability is a vector, and replacement is True
self
.
input_np
=
np
.
random
.
rand
(
4
)
self
.
input_np
=
np
.
random
.
rand
(
4
)
...
@@ -45,12 +53,14 @@ class TestMultinomialOp(OpTest):
...
@@ -45,12 +53,14 @@ class TestMultinomialOp(OpTest):
# normalize the input to get the probability
# normalize the input to get the probability
prob
=
self
.
input_np
/
self
.
input_np
.
sum
(
axis
=-
1
,
keepdims
=
True
)
prob
=
self
.
input_np
/
self
.
input_np
.
sum
(
axis
=-
1
,
keepdims
=
True
)
sample_prob
=
self
.
sample_output
(
np
.
array
(
outs
[
0
]))
sample_prob
=
self
.
sample_output
(
np
.
array
(
outs
[
0
]))
print
(
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
self
.
assertTrue
(
self
.
assertTrue
(
np
.
allclose
(
np
.
allclose
(
sample_prob
,
prob
,
rtol
=
0
,
atol
=
0.01
),
sample_prob
,
prob
,
rtol
=
0
,
atol
=
0.01
),
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
"sample_prob: "
+
str
(
sample_prob
)
+
"
\n
prob: "
+
str
(
prob
))
"""
class TestMultinomialOp2(TestMultinomialOp):
class TestMultinomialOp2(TestMultinomialOp):
def init_data(self):
def init_data(self):
# input probability is a matrix
# input probability is a matrix
...
@@ -82,8 +92,7 @@ class TestMultinomialOp3(TestMultinomialOp):
...
@@ -82,8 +92,7 @@ class TestMultinomialOp3(TestMultinomialOp):
self.assertEqual(
self.assertEqual(
len(unique_out), 100,
len(unique_out), 100,
"replacement is False. categories can't be sampled repeatedly")
"replacement is False. categories can't be sampled repeatedly")
"""
"""
"""
class TestReplacementError(unittest.TestCase):
class TestReplacementError(unittest.TestCase):
def init_data(self):
def init_data(self):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录