Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
158bf13f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
展开全部
隐藏空白更改
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
@@ -16,12 +16,12 @@
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
此差异已折叠。
点击以展开。
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
...
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
PT_REGISTER_KERNEL
(
cast
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
CastKernel
,
pten
::
CastKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
int16_t
,
int16_t
,
bool
,
bool
,
uint8_t
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{
paddle
::
platform
::
complex
<
double
>
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,13 +21,13 @@
...
@@ -21,13 +21,13 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,13 +20,13 @@
...
@@ -20,13 +20,13 @@
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
...
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
@@ -18,29 +18,29 @@ limitations under the License. */
...
@@ -18,29 +18,29 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
...
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
bool
,
bool
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -19,29 +19,29 @@ limitations under the License. */
...
@@ -19,29 +19,29 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,11 +20,11 @@ limitations under the License. */
...
@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
...
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,5 +21,4 @@ limitations under the License. */
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
...
@@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#endif
#endif
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
...
@@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
...
@@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
double
,
double
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
float
,
float
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
int8_t
,
int8_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
#endif
#endif
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
...
@@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)
\
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_
CTX_
KERNEL(cast, \
PT_REGISTER_KERNEL(cast, \
GPU, \
GPU, \
ALL_LAYOUT, \
ALL_LAYOUT, \
pten::CastKernel, \
pten::CastKernel, \
float, \
float, \
double, \
double, \
int, \
int, \
int64_t, \
int64_t, \
int16_t, \
int16_t, \
bool, \
bool, \
uint8_t, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType(
\
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED);
\
paddle::experimental::DataType::UNDEFINED); \
}
}
#if !defined(PADDLE_WITH_HIP)
#if !defined(PADDLE_WITH_HIP)
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
@@ -21,14 +21,14 @@
...
@@ -21,14 +21,14 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
,
paddle
::
platform
::
complex
<
double
>
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,13 +20,13 @@ limitations under the License. */
...
@@ -20,13 +20,13 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
...
@@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
@@ -18,28 +18,28 @@ limitations under the License. */
...
@@ -18,28 +18,28 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
float
,
float
,
double
,
double
,
uint8_t
,
uint8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
paddle
::
platform
::
float16
)
{}
paddle
::
platform
::
float16
)
{}
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
...
@@ -110,64 +110,64 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
,
int64_t
,
bool
,
bool
,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
bool
,
bool
,
float
,
float
,
double
,
double
,
float16
,
float16
,
int
,
int
,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
}
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -19,32 +19,32 @@ limitations under the License. */
...
@@ -19,32 +19,32 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,12 +20,12 @@ limitations under the License. */
...
@@ -20,12 +20,12 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
...
@@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
float
,
float
,
double
,
double
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
float16
,
uint8_t
,
uint8_t
,
int8_t
,
int8_t
,
int16_t
,
int16_t
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
@@ -23,5 +23,5 @@ limitations under the License. */
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录