Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b007a031
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b007a031
编写于
2月 09, 2022
作者:
N
niuliling123
提交者:
GitHub
2月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Delete BASE_SIZE in elementwise_base.h (#39390)
上级
2be20e20
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
29 addition
and
28 deletion
+29
-28
paddle/pten/core/utils/array.h
paddle/pten/core/utils/array.h
+4
-10
paddle/pten/kernels/funcs/elementwise_base.h
paddle/pten/kernels/funcs/elementwise_base.h
+7
-11
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+4
-5
paddle/pten/kernels/primitive/compute_primitives.h
paddle/pten/kernels/primitive/compute_primitives.h
+1
-2
paddle/pten/kernels/primitive/compute_primitives_xpu2.h
paddle/pten/kernels/primitive/compute_primitives_xpu2.h
+13
-0
未找到文件。
paddle/pten/core/utils/array.h
浏览文件 @
b007a031
...
...
@@ -104,28 +104,22 @@ class Array<T, 0> {
HOSTDEVICE
inline
T
*
GetMutable
()
{
return
nullptr
;
}
HOSTDEVICE
inline
T
&
operator
[](
size_t
)
{
#if defined(__HIPCC__)
// HIP will have compile error, if use "obj()"
#if defined(__HIPCC__)
|| defined(__CUDA_ARCH__)
// HIP
and CUDA
will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class
static
T
obj
{};
return
obj
;
#elif defined(__CUDA_ARCH__)
static
T
obj
();
return
obj
;
#else
PADDLE_THROW
(
pten
::
errors
::
Unavailable
(
"Array<T, 0> has no element."
));
#endif
}
HOSTDEVICE
inline
const
T
&
operator
[](
size_t
)
const
{
#if defined(__HIPCC__)
// HIP will have compile error, if use "obj()"
#if defined(__HIPCC__)
|| defined(__CUDA_ARCH__)
// HIP
and CUDA
will have compile error, if use "obj()"
// function declared in block scope cannot have 'static' storage class
static
const
T
obj
{};
return
obj
;
#elif defined(__CUDA_ARCH__)
static
const
T
obj
();
return
obj
;
#else
PADDLE_THROW
(
pten
::
errors
::
Unavailable
(
"Array<T, 0> has no element."
));
#endif
...
...
paddle/pten/kernels/funcs/elementwise_base.h
浏览文件 @
b007a031
...
...
@@ -31,8 +31,6 @@ namespace kps = pten::kps;
#endif
#define BASE_SIZE 1 // To avoid running errors when Arity == 0 in args[Arity]
namespace
pten
{
enum
ElementwiseType
{
kUnary
=
1
,
kBinary
=
2
,
kTernary
=
3
,
kAny
=
-
1
};
...
...
@@ -482,7 +480,7 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
kps
::
Elementwise
FillCons
t
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
kps
::
Elementwise
Constan
t
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
func
);
}
};
...
...
@@ -560,13 +558,12 @@ template <typename InT,
bool
IsBoundary
>
__device__
void
VectorizedElementwiseKernelImpl
(
const
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
+
BASE_SIZE
>
&
in
,
const
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
&
in
,
pten
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
int
num
,
int
data_offset
,
Functor
func
)
{
InT
args
[
Arity
+
BASE_SIZE
][
VecSize
];
InT
args
[
Arity
>
1
?
Arity
:
1
][
VecSize
];
ConditionalT
<
OutT
,
NumOuts
>
result
[
VecSize
];
#pragma unroll
...
...
@@ -596,8 +593,7 @@ template <typename InT,
int
NumOuts
,
int
VecSize
>
__global__
void
VectorizedElementwiseKernel
(
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
+
BASE_SIZE
>
ins
,
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins
,
pten
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs
,
int
size
,
int
main_offset
,
...
...
@@ -637,9 +633,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
const
std
::
vector
<
const
DenseTensor
*>
&
ins
,
std
::
vector
<
DenseTensor
*>
*
outs
,
Functor
func
)
{
auto
numel
=
(
*
outs
)[
0
]
->
numel
();
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
+
BASE_SIZE
>
ins_data
;
auto
numel
=
(
*
outs
)[
0
]
->
numel
();
// To avoid running errors when ins.size()== 0
pten
::
framework
::
Array
<
const
_ptr_
InT
*
__restrict__
,
Arity
>
ins_data
;
pten
::
framework
::
Array
<
_ptr_
OutT
*
,
NumOuts
>
outs_data
;
for
(
int
i
=
0
;
i
<
Arity
;
++
i
)
{
...
...
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
b007a031
...
...
@@ -62,8 +62,7 @@ void FullLikeKernel(const ContextT& dev_ctx,
auto
value
=
val
.
to
<
float
>
();
using
CommonType
=
typename
std
::
common_type
<
float
,
typename
std
::
conditional
<
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
,
typename
std
::
conditional
<
std
::
is_same
<
T
,
pten
::
dtype
::
float16
>::
value
,
float
,
T
>::
type
>::
type
;
...
...
@@ -75,7 +74,7 @@ void FullLikeKernel(const ContextT& dev_ctx,
(
common_type_value
<=
static_cast
<
CommonType
>
(
std
::
numeric_limits
<
T
>::
max
())),
true
,
p
addle
::
platform
::
errors
::
InvalidArgument
(
p
ten
::
errors
::
InvalidArgument
(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f."
,
...
...
paddle/pten/kernels/primitive/compute_primitives.h
浏览文件 @
b007a031
...
...
@@ -420,8 +420,7 @@ template <typename InT,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseFillConst
(
OutT
*
out
,
OpFunc
compute
)
{
__device__
__forceinline__
void
ElementwiseConstant
(
OutT
*
out
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
idx
++
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
());
...
...
paddle/pten/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
b007a031
...
...
@@ -348,5 +348,18 @@ __device__ __forceinline__ void Reduce(T* out,
}
}
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseConstant
(
OutT
*
out
,
OpFunc
compute
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
NX
*
NY
;
idx
++
)
{
out
[
idx
]
=
static_cast
<
OutT
>
(
compute
());
}
}
}
// namespace kps
}
// namespace pten
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录