Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5076ad87
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5076ad87
编写于
2月 13, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move histogram to pten; test=develop
上级
74a150fe
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
173 addition
and
69 deletion
+173
-69
.gitignore
.gitignore
+2
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-0
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+2
-0
paddle/fluid/operators/histogram_op.cc
paddle/fluid/operators/histogram_op.cc
+6
-6
paddle/pten/kernels/cpu/histogram_kernel.cc
paddle/pten/kernels/cpu/histogram_kernel.cc
+77
-0
paddle/pten/kernels/gpu/histogram_kernel.cu
paddle/pten/kernels/gpu/histogram_kernel.cu
+53
-63
paddle/pten/kernels/histogram_kernel.h
paddle/pten/kernels/histogram_kernel.h
+16
-0
paddle/pten/ops/compat/histogram_sig.cc
paddle/pten/ops/compat/histogram_sig.cc
+15
-0
未找到文件。
.gitignore
浏览文件 @
5076ad87
...
...
@@ -51,3 +51,5 @@ paddle/infrt/dialect/pd_ops_info.h
.lit_test_times.txt
paddle/infrt/tests/dialect/Output
paddle/infrt/tests/lit.cfg.py
paddle/fluid/pybind/eager_final_state_op_function_impl.h
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h
paddle/fluid/framework/operator.cc
浏览文件 @
5076ad87
...
...
@@ -2155,6 +2155,8 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
float
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
bool
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
bool
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
int64_t
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
int64_t
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
string
,
attr
));
...
...
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
5076ad87
...
...
@@ -412,6 +412,8 @@ void BuildDygraphPtenKernelContext(
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
float
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
bool
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
bool
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
int64_t
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
int64_t
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
std
::
string
,
attr
));
...
...
paddle/fluid/operators/histogram_op.cc
浏览文件 @
5076ad87
...
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/
operators/histogram_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
#include <string>
#include <unordered_map>
...
...
@@ -85,8 +85,8 @@ REGISTER_OPERATOR(
histogram
,
ops
::
HistogramOp
,
ops
::
HistogramOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
histogram
,
ops
::
HistogramKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
HistogramKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
HistogramKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
HistogramKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
//
REGISTER_OP_CPU_KERNEL(
//
histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
//
ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
//
ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
//
ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle/
fluid/operators/histogram_op.h
→
paddle/
pten/kernels/cpu/histogram_kernel.cc
浏览文件 @
5076ad87
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
using
Tensor
=
framework
::
Tensor
;
namespace
pten
{
template
<
typename
DeviceContext
,
typename
T
>
class
HistogramKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
Tensor
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
&
nbins
=
context
.
Attr
<
int64_t
>
(
"bins"
);
auto
&
minval
=
context
.
Attr
<
int
>
(
"min"
);
auto
&
maxval
=
context
.
Attr
<
int
>
(
"max"
);
template
<
typename
T
,
typename
Context
>
void
HistogramKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
int64_t
bins
,
int
min
,
int
max
,
DenseTensor
*
output
)
{
auto
&
nbins
=
bins
;
auto
&
minval
=
min
;
auto
&
maxval
=
max
;
const
T
*
input_data
=
input
->
data
<
T
>
();
auto
input_numel
=
input
->
numel
();
const
T
*
input_data
=
input
.
data
<
T
>
();
auto
input_numel
=
input
.
numel
();
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
Device
Context
,
int64_t
>
()(
context
.
template
device_context
<
DeviceContext
>()
,
output
,
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
dev_ctx
,
output
,
static_cast
<
int64_t
>
(
0
));
if
(
input_data
==
nullptr
)
return
;
...
...
@@ -61,10 +46,10 @@ class HistogramKernel : public framework::OpKernel<T> {
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
false
,
p
latform
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
false
,
p
ten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
PADDLE_ENFORCE_GE
(
output_max
,
output_min
,
p
latform
::
errors
::
InvalidArgument
(
p
ten
::
errors
::
InvalidArgument
(
"max must be larger or equal to min. If min and max are both zero, "
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d"
,
...
...
@@ -77,8 +62,16 @@ class HistogramKernel : public framework::OpKernel<T> {
out_data
[
std
::
min
(
bin
,
nbins
-
1
)]
+=
1
;
}
}
}
};
}
}
// namspace pten
}
// namespace operators
}
// namespace paddle
PT_REGISTER_KERNEL
(
histogram
,
CPU
,
ALL_LAYOUT
,
pten
::
HistogramKernel
,
float
,
double
,
int
,
int64_t
)
{}
\ No newline at end of file
paddle/
fluid/operators/histogram_op
.cu
→
paddle/
pten/kernels/gpu/histogram_kernel
.cu
浏览文件 @
5076ad87
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/funcs/math_function.h"
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/histogram_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/core/hostdevice.h"
namespace
paddle
{
namespace
operators
{
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
namespace
pten
{
using
IndexType
=
int64_t
;
using
Tensor
=
framework
::
Tensor
;
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
inline
int
GET_BLOCKS
(
const
int
N
)
{
return
(
N
+
PADDLE_CUDA_NUM_THREADS
-
1
)
/
PADDLE_CUDA_NUM_THREADS
;
...
...
@@ -64,26 +56,24 @@ __global__ void KernelHistogram(const T* input, const int total_elements,
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
HistogramCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
true
,
platform
::
errors
::
InvalidArgument
(
"It must use CUDAPlace."
));
const
Tensor
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
Tensor
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
&
nbins
=
context
.
Attr
<
int64_t
>
(
"bins"
);
auto
&
minval
=
context
.
Attr
<
int
>
(
"min"
);
auto
&
maxval
=
context
.
Attr
<
int
>
(
"max"
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
int
input_numel
=
input
->
numel
();
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
platform
::
CUDADeviceContext
,
int64_t
>
()(
context
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
output
,
template
<
typename
T
,
typename
Context
>
void
HistogramKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
int64_t
bins
,
int
min
,
int
max
,
DenseTensor
*
output
)
{
auto
&
nbins
=
bins
;
auto
&
minval
=
min
;
auto
&
maxval
=
max
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
int
input_numel
=
input
.
numel
();
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
dev_ctx
,
output
,
static_cast
<
int64_t
>
(
0
));
if
(
input_data
==
nullptr
)
return
;
...
...
@@ -92,25 +82,25 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
T
output_max
=
static_cast
<
T
>
(
maxval
);
if
(
output_min
==
output_max
)
{
auto
input_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
input_x
=
pten
::
EigenVector
<
T
>::
Flatten
(
input
);
framework
::
Tensor
input_min_t
,
input_max_t
;
Dense
Tensor
input_min_t
,
input_max_t
;
auto
*
input_min_data
=
input_min_t
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
input_min_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
auto
*
input_max_data
=
input_max_t
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
auto
input_min_scala
=
framework
::
EigenScalar
<
T
>::
From
(
input_min_t
);
auto
input_max_scala
=
framework
::
EigenScalar
<
T
>::
From
(
input_max_t
);
input_max_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
auto
input_min_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_min_t
);
auto
input_max_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_max_t
);
auto
*
place
=
context
.
template
device_context
<
DeviceContext
>()
.
eigen_device
();
dev_ctx
.
eigen_device
();
input_min_scala
.
device
(
*
place
)
=
input_x
.
minimum
();
input_max_scala
.
device
(
*
place
)
=
input_x
.
maximum
();
Tensor
input_min_cpu
,
input_max_cpu
;
paddle
::
framework
::
TensorCopySync
(
input_min_t
,
platform
::
CPUPlace
(),
Dense
Tensor
input_min_cpu
,
input_max_cpu
;
paddle
::
framework
::
TensorCopySync
(
input_min_t
,
p
addle
::
p
latform
::
CPUPlace
(),
&
input_min_cpu
);
paddle
::
framework
::
TensorCopySync
(
input_max_t
,
platform
::
CPUPlace
(),
paddle
::
framework
::
TensorCopySync
(
input_max_t
,
p
addle
::
p
latform
::
CPUPlace
(),
&
input_max_cpu
);
output_min
=
input_min_cpu
.
data
<
T
>
()[
0
];
...
...
@@ -126,31 +116,31 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
false
,
p
latform
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
false
,
p
ten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
PADDLE_ENFORCE_GE
(
output_max
,
output_min
,
p
latform
::
errors
::
InvalidArgument
(
p
ten
::
errors
::
InvalidArgument
(
"max must be larger or equal to min. If min and max are both zero, "
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d"
,
maxval
,
minval
));
auto
stream
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>()
.
stream
();
dev_ctx
.
stream
();
KernelHistogram
<
T
,
IndexType
><<<
GET_BLOCKS
(
input_numel
),
PADDLE_CUDA_NUM_THREADS
,
nbins
*
sizeof
(
int64_t
),
stream
>>>
(
input_data
,
input_numel
,
nbins
,
output_min
,
output_max
,
out_data
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
histogram
,
ops
::
HistogramCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
HistogramCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
HistogramCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
HistogramCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
}
}
//namespace pten
PT_REGISTER_KERNEL
(
histogram
,
GPU
,
ALL_LAYOUT
,
pten
::
HistogramKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/pten/kernels/histogram_kernel.h
0 → 100644
浏览文件 @
5076ad87
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace
pten
{
template
<
typename
T
,
typename
Context
>
void
HistogramSelectKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
int64_t
bins
,
int
min
,
int
max
,
DenseTensor
*
out
);
}
// namspace pten
paddle/pten/ops/compat/histogram_sig.cc
0 → 100644
浏览文件 @
5076ad87
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
KernelSignature
HistogramOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"histogram"
,
{
"X"
},
{
"bins"
,
"min"
,
"max"
},
{
"Out"
});
}
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
histogram
,
pten
::
HistogramOpArgumentMapping
);
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录