Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
158bf13f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
展开全部
显示空白变更内容
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
@@ -16,12 +16,12 @@
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
此差异已折叠。
点击以展开。
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
...
@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
PT_REGISTER_KERNEL
(
cast
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
CastKernel
,
pten
::
CastKernel
,
...
...
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
...
@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
...
@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
...
@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
...
@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
...
@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
...
@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
...
@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
...
@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
...
@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
...
@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
...
@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
...
@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
bool
,
bool
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
...
@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
...
@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
...
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
...
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
...
@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,5 +21,4 @@ limitations under the License. */
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
...
@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
...
@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
...
@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
...
@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
...
@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
...
@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
...
@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
...
...
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
...
@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
...
@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
...
@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
...
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
...
@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
...
@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
...
@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
...
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
...
@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_
CTX_
KERNEL(cast, \
PT_REGISTER_KERNEL(cast, \
GPU, \
GPU, \
ALL_LAYOUT, \
ALL_LAYOUT, \
pten::CastKernel, \
pten::CastKernel, \
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
...
@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
...
@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
...
@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
...
@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
...
@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
...
@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
...
@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
...
@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
...
@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
...
@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
...
@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
...
@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
...
@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
...
@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
...
@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
...
@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
...
@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
@@ -23,5 +23,5 @@ limitations under the License. */
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录