Commit f373269d (unverified)
Authored Sep 30, 2020 by Qi Li; committed via GitHub on Sep 30, 2020
Parent: 4d5ddbf1

update histogram op for performance optimization, test=develop (#24912)

Showing 3 changed files with 90 additions and 45 deletions (+90, -45):
paddle/fluid/operators/histogram_op.cu                      +29  -20
python/paddle/fluid/tests/unittests/test_histogram_op.py    +56   -2
python/paddle/tensor/linalg.py                               +5  -23
paddle/fluid/operators/histogram_op.cu

@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/histogram_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -32,28 +30,38 @@ inline int GET_BLOCKS(const int N) {
 }
 
 template <typename T, typename IndexType>
-__device__ static IndexType GetBin(T bVal, T minvalue, T maxvalue,
-                                   int64_t nbins) {
-  IndexType bin = static_cast<int>((bVal - minvalue) * nbins /
-                                   (maxvalue - minvalue));
-  if (bin == nbins) bin -= 1;
-  return bin;
+__device__ static IndexType GetBin(T input_value, T min_value, T max_value,
+                                   int64_t nbins) {
+  IndexType bin = static_cast<int>((input_value - min_value) * nbins /
+                                   (max_value - min_value));
+  IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
+  return output_index;
 }
 
 template <typename T, typename IndexType>
-__global__ void KernelHistogram(const T* input, const int totalElements,
-                                const int64_t nbins, const T minvalue,
-                                const T maxvalue, int64_t* output) {
-  CUDA_KERNEL_LOOP(linearIndex, totalElements) {
-    const IndexType inputIdx = threadIdx.x + blockIdx.x * blockDim.x;
-    const auto inputVal = input[inputIdx];
-    if (inputVal >= minvalue && inputVal <= maxvalue) {
-      const IndexType bin =
-          GetBin<T, IndexType>(inputVal, minvalue, maxvalue, nbins);
-      const IndexType outputIdx = bin < nbins - 1 ? bin : nbins - 1;
-      paddle::platform::CudaAtomicAdd(&output[outputIdx], 1);
+__global__ void KernelHistogram(const T* input, const int total_elements,
+                                const int64_t nbins, const T min_value,
+                                const T max_value, int64_t* output) {
+  extern __shared__ int64_t buf_hist[];
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    buf_hist[i] = 0;
+  }
+  __syncthreads();
+
+  CUDA_KERNEL_LOOP(input_index, total_elements) {
+    // const IndexType input_index = threadIdx.x + blockIdx.x * blockDim.x;
+    const auto input_value = input[input_index];
+    if (input_value >= min_value && input_value <= max_value) {
+      const IndexType output_index =
+          GetBin<T, IndexType>(input_value, min_value, max_value, nbins);
+      paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1);
     }
   }
+  __syncthreads();
+
+  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
+    paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
+  }
 }
 
 template <typename DeviceContext, typename T>
@@ -125,8 +133,9 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
     auto stream =
         context.template device_context<platform::CUDADeviceContext>().stream();
     KernelHistogram<T, IndexType><<<GET_BLOCKS(input_numel),
-                                    PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+                                    PADDLE_CUDA_NUM_THREADS,
+                                    nbins * sizeof(int64_t), stream>>>(
         input_data, input_numel, nbins, output_min, output_max, out_data);
   }
 };
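Why this speeds things up: in the old kernel every input element issued an atomic add directly into the global output array, so all blocks contended on the same nbins counters. The new kernel gives each thread block a private histogram in dynamic shared memory (buf_hist, sized nbins * sizeof(int64_t) at launch), accumulates into it with block-local atomics, and flushes it to global memory once per bin per block. The NumPy sketch below is only a model of that privatization pattern, showing that the result is unchanged; the chunking and the names get_bin and histogram_privatized are illustrative, not Paddle code.

import numpy as np

def get_bin(value, min_value, max_value, nbins):
    # Mirrors GetBin: linear bin index, with the top edge folded into the last bin.
    bin_id = int((value - min_value) * nbins / (max_value - min_value))
    return min(bin_id, nbins - 1)

def histogram_privatized(data, nbins, min_value, max_value, num_blocks=4):
    # Each "block" fills a private histogram (the shared-memory buf_hist),
    # then the partial histograms are added into the global result once.
    partials = np.zeros((num_blocks, nbins), dtype=np.int64)
    for block_id, chunk in enumerate(np.array_split(data, num_blocks)):
        for value in chunk:
            if min_value <= value <= max_value:
                partials[block_id, get_bin(value, min_value, max_value, nbins)] += 1
    return partials.sum(axis=0)

data = np.random.uniform(0.0, 20.0, size=1000)
reference = np.zeros(8, dtype=np.int64)
for value in data:  # the old scheme: one update on the shared result per element
    reference[get_bin(value, 0.0, 20.0, 8)] += 1
assert np.array_equal(histogram_privatized(data, 8, 0.0, 20.0), reference)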
python/paddle/fluid/tests/unittests/test_histogram_op.py

@@ -58,12 +58,66 @@ class TestHistogramOpAPI(unittest.TestCase):
             msg='histogram output is wrong, out =' + str(actual.numpy()))
 
 
+class TestHistogramOpError(unittest.TestCase):
+    """Test histogram op error."""
+
+    def run_network(self, net_func):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            net_func()
+        exe = fluid.Executor()
+        exe.run(main_program)
+
+    def test_bins_error(self):
+        """Test bins should be greater than or equal to 1."""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=-1, min=1, max=5)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_min_max_error(self):
+        """Test max must be larger or equal to min."""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=1, min=5, max=1)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_min_max_range_error(self):
+        """Test range of min, max is not finite"""
+
+        def net_func():
+            input_value = paddle.fill_constant(
+                shape=[3, 4], dtype='float32', value=3.0)
+            paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)
+
+        with self.assertRaises(fluid.core.EnforceNotMet):
+            self.run_network(net_func)
+
+    def test_type_errors(self):
+        with program_guard(Program()):
+            # The input type must be Variable.
+            self.assertRaises(
+                TypeError, paddle.histogram, 1, bins=5, min=1, max=5)
+            # The input type must be 'int32', 'int64', 'float32', 'float64'
+            x_bool = fluid.data(name='x_bool', shape=[4, 3], dtype='bool')
+            self.assertRaises(
+                TypeError, paddle.histogram, x_bool, bins=5, min=1, max=5)
+
+
 class TestHistogramOp(OpTest):
     def setUp(self):
         self.op_type = "histogram"
         self.init_test_case()
-        np_input = np.random.randint(
-            low=0, high=20, size=self.in_shape, dtype=np.int64)
+        np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape)
         self.inputs = {"X": np_input}
         self.init_attrs()
         Out, _ = np.histogram(
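With setUp switching from np.random.randint to np.random.uniform, the op test now feeds the kernel floating-point data while still taking its reference result from np.histogram. A minimal sketch of that comparison, using illustrative bins/min/max values rather than the attributes the test actually sets:

import numpy as np

in_shape = (10, 12)
bins, hist_min, hist_max = 5, 1, 5  # illustrative, not the test's real attributes

np_input = np.random.uniform(low=0.0, high=20.0, size=in_shape)
# np.histogram with an explicit range counts only values inside [min, max] and
# closes the last bin at max, matching the kernel's range filter and the clamp
# in GetBin that folds the top edge into bin nbins - 1.
expected, _ = np.histogram(np_input, bins=bins, range=(hist_min, hist_max))
print(expected)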
python/paddle/tensor/linalg.py

@@ -862,41 +862,23 @@ def histogram(input, bins=100, min=0, max=0):
     If min and max are both zero, the minimum and maximum values of the data are used.
 
     Args:
-        input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
+        input (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
             should be float32, float64, int32, int64.
         bins (int): number of histogram bins
         min (int): lower end of the range (inclusive)
         max (int): upper end of the range (inclusive)
 
     Returns:
-        Variable: Tensor or LoDTensor calculated by histogram layer. The data type is int64.
+        Tensor: data type is int64, shape is (nbins,).
 
-    Code Example 1:
-        .. code-block:: python
-            import paddle
-            import numpy as np
-            startup_program = paddle.static.Program()
-            train_program = paddle.static.Program()
-            with paddle.static.program_guard(train_program, startup_program):
-                inputs = paddle.data(name='input', dtype='int32', shape=[2,3])
-                output = paddle.histogram(inputs, bins=5, min=1, max=5)
-                place = paddle.CPUPlace()
-                exe = paddle.static.Executor(place)
-                exe.run(startup_program)
-                img = np.array([[2, 4, 2], [2, 5, 4]]).astype(np.int32)
-                res = exe.run(train_program,
-                              feed={'input': img},
-                              fetch_list=[output])
-                print(np.array(res[0])) # [0,3,0,2,1]
-
-    Code Example 2:
+    Examples:
         .. code-block:: python
+
             import paddle
-            paddle.disable_static(paddle.CPUPlace())
+
             inputs = paddle.to_tensor([1, 2, 1])
             result = paddle.histogram(inputs, bins=4, min=0, max=3)
             print(result) # [0, 2, 1, 0]
-            paddle.enable_static()
     """
     if in_dygraph_mode():
         return core.ops.histogram(input, "bins", bins, "min", min, "max", max)
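The dygraph example kept in the new docstring can be cross-checked against NumPy, since np.histogram with an explicit range applies the same inclusive upper bound on the last bin:

import numpy as np

values = np.array([1, 2, 1])
counts, _ = np.histogram(values, bins=4, range=(0, 3))
print(counts)  # [0 2 1 0], the same counts the docstring reports for
               # paddle.histogram(inputs, bins=4, min=0, max=3)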