Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
158bf13f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
158bf13f
编写于
1月 13, 2022
作者:
C
Chen Weihang
提交者:
GitHub
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PTen] Rename kernel register marco (#38861)
* rename register marco * fix error changing * fix format error
上级
dccdc719
变更
25
展开全部
显示空白变更内容
内联
并排
Showing
25 changed file
with
636 addition
and
1193 deletion
+636
-1193
cmake/pten_kernel.cmake
cmake/pten_kernel.cmake
+3
-3
paddle/pten/core/kernel_registry.h
paddle/pten/core/kernel_registry.h
+132
-688
paddle/pten/kernels/cpu/cast_kernel.cc
paddle/pten/kernels/cpu/cast_kernel.cc
+15
-15
paddle/pten/kernels/cpu/complex_kernel.cc
paddle/pten/kernels/cpu/complex_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_grad_kernel.cc
paddle/pten/kernels/cpu/dot_grad_kernel.cc
+10
-10
paddle/pten/kernels/cpu/dot_kernel.cc
paddle/pten/kernels/cpu/dot_kernel.cc
+10
-10
paddle/pten/kernels/cpu/full_kernel.cc
paddle/pten/kernels/cpu/full_kernel.cc
+25
-25
paddle/pten/kernels/cpu/math_kernel.cc
paddle/pten/kernels/cpu/math_kernel.cc
+54
-54
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
+26
-26
paddle/pten/kernels/cpu/matmul_kernel.cc
paddle/pten/kernels/cpu/matmul_kernel.cc
+8
-8
paddle/pten/kernels/cpu/scale_kernel.cc
paddle/pten/kernels/cpu/scale_kernel.cc
+12
-12
paddle/pten/kernels/cpu/sign_kernel.cc
paddle/pten/kernels/cpu/sign_kernel.cc
+1
-2
paddle/pten/kernels/empty_kernel.cc
paddle/pten/kernels/empty_kernel.cc
+58
-58
paddle/pten/kernels/flatten_grad_kernel.cc
paddle/pten/kernels/flatten_grad_kernel.cc
+30
-30
paddle/pten/kernels/flatten_kernel.cc
paddle/pten/kernels/flatten_kernel.cc
+60
-60
paddle/pten/kernels/gpu/cast_kernel.cu
paddle/pten/kernels/gpu/cast_kernel.cu
+18
-18
paddle/pten/kernels/gpu/complex_kernel.cu
paddle/pten/kernels/gpu/complex_kernel.cu
+11
-11
paddle/pten/kernels/gpu/dot_grad_kernel.cu
paddle/pten/kernels/gpu/dot_grad_kernel.cu
+10
-10
paddle/pten/kernels/gpu/dot_kernel.cu
paddle/pten/kernels/gpu/dot_kernel.cu
+10
-10
paddle/pten/kernels/gpu/full_kernel.cu
paddle/pten/kernels/gpu/full_kernel.cu
+24
-24
paddle/pten/kernels/gpu/math_kernel.cu
paddle/pten/kernels/gpu/math_kernel.cu
+58
-58
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
+29
-29
paddle/pten/kernels/gpu/matmul_kernel.cu
paddle/pten/kernels/gpu/matmul_kernel.cu
+9
-9
paddle/pten/kernels/gpu/scale_kernel.cu
paddle/pten/kernels/gpu/scale_kernel.cu
+12
-12
paddle/pten/kernels/gpu/sign_kernel.cu
paddle/pten/kernels/gpu/sign_kernel.cu
+1
-1
未找到文件。
cmake/pten_kernel.cmake
浏览文件 @
158bf13f
...
@@ -16,12 +16,12 @@
...
@@ -16,12 +16,12 @@
function
(
kernel_declare TARGET_LIST
)
function
(
kernel_declare TARGET_LIST
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
foreach
(
kernel_path
${
TARGET_LIST
}
)
file
(
READ
${
kernel_path
}
kernel_impl
)
file
(
READ
${
kernel_path
}
kernel_impl
)
# TODO(chenweihang): rename PT_REGISTER_
CTX_
KERNEL to PT_REGISTER_KERNEL
# TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
# NOTE(chenweihang): now we don't recommend to use digit in kernel name
string
(
REGEX MATCH
"(PT_REGISTER_
CTX_
KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
string
(
REGEX MATCH
"(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)
\\
([
\t\r\n
]*[a-z0-9_]*,"
first_registry
"
${
kernel_impl
}
"
)
if
(
NOT first_registry STREQUAL
""
)
if
(
NOT first_registry STREQUAL
""
)
# parse the first kernel name
# parse the first kernel name
string
(
REPLACE
"PT_REGISTER_
CTX_
KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_KERNEL("
""
kernel_name
"
${
first_registry
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
"PT_REGISTER_GENERAL_KERNEL("
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REPLACE
","
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
string
(
REGEX REPLACE
"[
\t\r\n
]+"
""
kernel_name
"
${
kernel_name
}
"
)
...
...
paddle/pten/core/kernel_registry.h
浏览文件 @
158bf13f
此差异已折叠。
点击以展开。
paddle/pten/kernels/cpu/cast_kernel.cc
浏览文件 @
158bf13f
...
@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
...
@@ -58,7 +58,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
cast
,
PT_REGISTER_KERNEL
(
cast
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
CastKernel
,
pten
::
CastKernel
,
...
...
paddle/pten/kernels/cpu/complex_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/cpu/dot_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/cpu/dot_kernel.cc
浏览文件 @
158bf13f
...
@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
...
@@ -49,7 +49,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/cpu/full_kernel.cc
浏览文件 @
158bf13f
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
...
@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
...
@@ -34,7 +34,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/cpu/math_kernel.cc
浏览文件 @
158bf13f
...
@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
...
@@ -118,9 +118,9 @@ using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
mean
,
CPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
...
@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
...
@@ -130,7 +130,7 @@ PT_REGISTER_CTX_KERNEL(add,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
...
@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
...
@@ -140,7 +140,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
...
@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
...
@@ -150,7 +150,7 @@ PT_REGISTER_CTX_KERNEL(divide,
int64_t
,
int64_t
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
...
@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
...
@@ -161,7 +161,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
bool
,
bool
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/cpu/matmul_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
...
@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
...
@@ -28,7 +28,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
...
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
...
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/cpu/matmul_kernel.cc
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/cpu/scale_kernel.cc
浏览文件 @
158bf13f
...
@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
...
@@ -51,7 +51,7 @@ void ScaleKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/cpu/sign_kernel.cc
浏览文件 @
158bf13f
...
@@ -21,5 +21,4 @@ limitations under the License. */
...
@@ -21,5 +21,4 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/bfloat16.h"
PT_REGISTER_CTX_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{
PT_REGISTER_KERNEL
(
sign
,
CPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
)
{}
}
paddle/pten/kernels/empty_kernel.cc
浏览文件 @
158bf13f
...
@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
...
@@ -34,7 +34,7 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
...
@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
...
@@ -50,7 +50,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
...
@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
...
@@ -67,7 +67,7 @@ PT_REGISTER_CTX_KERNEL(empty_like,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
empty
,
PT_REGISTER_KERNEL
(
empty
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyKernel
,
pten
::
EmptyKernel
,
...
@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
...
@@ -82,7 +82,7 @@ PT_REGISTER_CTX_KERNEL(empty,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
empty_like
,
PT_REGISTER_KERNEL
(
empty_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
EmptyLikeKernel
,
pten
::
EmptyLikeKernel
,
...
...
paddle/pten/kernels/flatten_grad_kernel.cc
浏览文件 @
158bf13f
...
@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
...
@@ -33,7 +33,7 @@ void FlattenGradKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
...
@@ -45,7 +45,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
...
@@ -60,7 +60,7 @@ PT_REGISTER_CTX_KERNEL(flatten_grad,
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten_grad
,
PT_REGISTER_KERNEL
(
flatten_grad
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenGradKernel
,
pten
::
FlattenGradKernel
,
...
...
paddle/pten/kernels/flatten_kernel.cc
浏览文件 @
158bf13f
...
@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
...
@@ -48,7 +48,7 @@ void FlattenWithXShape(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -59,7 +59,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
...
@@ -71,7 +71,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
int64_t
)
{}
int64_t
)
{}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -83,7 +83,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
...
@@ -97,7 +97,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
#endif
#endif
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
PT_REGISTER_
CTX_
KERNEL
(
flatten
,
PT_REGISTER_KERNEL
(
flatten
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenKernel
,
pten
::
FlattenKernel
,
...
@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
...
@@ -107,7 +107,7 @@ PT_REGISTER_CTX_KERNEL(flatten,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
PT_REGISTER_
CTX_
KERNEL
(
flatten_with_xshape
,
PT_REGISTER_KERNEL
(
flatten_with_xshape
,
XPU
,
XPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FlattenWithXShape
,
pten
::
FlattenWithXShape
,
...
...
paddle/pten/kernels/gpu/cast_kernel.cu
浏览文件 @
158bf13f
...
@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
...
@@ -61,7 +61,7 @@ void CastKernel(const Context& dev_ctx,
}
// namespace pten
}
// namespace pten
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_
CTX_
KERNEL(cast, \
PT_REGISTER_KERNEL(cast, \
GPU, \
GPU, \
ALL_LAYOUT, \
ALL_LAYOUT, \
pten::CastKernel, \
pten::CastKernel, \
...
...
paddle/pten/kernels/gpu/complex_kernel.cu
浏览文件 @
158bf13f
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
// See Note [ Why still include the fluid headers? ]
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
conj
,
PT_REGISTER_KERNEL
(
conj
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ConjKernel
,
pten
::
ConjKernel
,
...
...
paddle/pten/kernels/gpu/dot_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_
CTX_
KERNEL
(
dot_grad
,
PT_REGISTER_KERNEL
(
dot_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotGradKernel
,
pten
::
DotGradKernel
,
...
...
paddle/pten/kernels/gpu/dot_kernel.cu
浏览文件 @
158bf13f
...
@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
...
@@ -52,7 +52,7 @@ void DotKernel(const Context& dev_ctx,
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
dot
,
PT_REGISTER_KERNEL
(
dot
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DotKernel
,
pten
::
DotKernel
,
...
...
paddle/pten/kernels/gpu/full_kernel.cu
浏览文件 @
158bf13f
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
full
,
PT_REGISTER_KERNEL
(
full
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullKernel
,
pten
::
FullKernel
,
...
@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
...
@@ -33,7 +33,7 @@ PT_REGISTER_CTX_KERNEL(full,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
full_like
,
PT_REGISTER_KERNEL
(
full_like
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
FullLikeKernel
,
pten
::
FullLikeKernel
,
...
...
paddle/pten/kernels/gpu/math_kernel.cu
浏览文件 @
158bf13f
...
@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
...
@@ -110,9 +110,9 @@ using float16 = paddle::platform::float16;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex64
=
::
paddle
::
platform
::
complex
<
float
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
using
complex128
=
::
paddle
::
platform
::
complex
<
double
>
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
mean
,
GPU
,
ALL_LAYOUT
,
pten
::
MeanKernel
,
float
,
double
,
bool
,
float16
)
{}
PT_REGISTER_
CTX_
KERNEL
(
add
,
PT_REGISTER_KERNEL
(
add
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
AddKernel
,
pten
::
AddKernel
,
...
@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
...
@@ -123,7 +123,7 @@ PT_REGISTER_CTX_KERNEL(add,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
subtract
,
PT_REGISTER_KERNEL
(
subtract
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SubtractKernel
,
pten
::
SubtractKernel
,
...
@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
...
@@ -134,7 +134,7 @@ PT_REGISTER_CTX_KERNEL(subtract,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
divide
,
PT_REGISTER_KERNEL
(
divide
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
DivideKernel
,
pten
::
DivideKernel
,
...
@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
...
@@ -145,7 +145,7 @@ PT_REGISTER_CTX_KERNEL(divide,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
multiply
,
PT_REGISTER_KERNEL
(
multiply
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MultiplyKernel
,
pten
::
MultiplyKernel
,
...
@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
...
@@ -157,7 +157,7 @@ PT_REGISTER_CTX_KERNEL(multiply,
float16
,
float16
,
complex64
,
complex64
,
complex128
)
{}
complex128
)
{}
PT_REGISTER_
CTX_
KERNEL
(
sum
,
PT_REGISTER_KERNEL
(
sum
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
SumKernel
,
pten
::
SumKernel
,
...
...
paddle/pten/kernels/gpu/matmul_grad_kernel.cu
浏览文件 @
158bf13f
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul_grad
,
PT_REGISTER_KERNEL
(
matmul_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulGradKernel
,
pten
::
MatmulGradKernel
,
...
@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
...
@@ -29,7 +29,7 @@ PT_REGISTER_CTX_KERNEL(matmul_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_double_grad
,
PT_REGISTER_KERNEL
(
matmul_double_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulDoubleGradKernel
,
pten
::
MatmulDoubleGradKernel
,
...
@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
...
@@ -39,7 +39,7 @@ PT_REGISTER_CTX_KERNEL(matmul_double_grad,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle
::
platform
::
complex
<
double
>
)
{}
PT_REGISTER_
CTX_
KERNEL
(
matmul_triple_grad
,
PT_REGISTER_KERNEL
(
matmul_triple_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulTripleGradKernel
,
pten
::
MatmulTripleGradKernel
,
...
...
paddle/pten/kernels/gpu/matmul_kernel.cu
浏览文件 @
158bf13f
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
PT_REGISTER_
CTX_
KERNEL
(
matmul
,
PT_REGISTER_KERNEL
(
matmul
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
MatmulKernel
,
pten
::
MatmulKernel
,
...
...
paddle/pten/kernels/gpu/scale_kernel.cu
浏览文件 @
158bf13f
...
@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
...
@@ -64,7 +64,7 @@ void ScaleKernel(const ContextT& dev_ctx,
}
// namespace pten
}
// namespace pten
PT_REGISTER_
CTX_
KERNEL
(
scale
,
PT_REGISTER_KERNEL
(
scale
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
pten
::
ScaleKernel
,
pten
::
ScaleKernel
,
...
...
paddle/pten/kernels/gpu/sign_kernel.cu
浏览文件 @
158bf13f
...
@@ -23,5 +23,5 @@ limitations under the License. */
...
@@ -23,5 +23,5 @@ limitations under the License. */
using
float16
=
paddle
::
platform
::
float16
;
using
float16
=
paddle
::
platform
::
float16
;
PT_REGISTER_
CTX_
KERNEL
(
PT_REGISTER_KERNEL
(
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
sign
,
GPU
,
ALL_LAYOUT
,
pten
::
SignKernel
,
float
,
double
,
float16
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录