Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5c3873f6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c3873f6
编写于
2月 08, 2022
作者:
S
sneaxiy
提交者:
GitHub
2月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add __PD_DEFINE_RAW_OP_KERNEL_FUNC for registering custom op kernel with ExecutionContext (#39352)
* hack custom op * add ut * skip windows ci
上级
fee4316d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
395 addition
and
37 deletion
+395
-37
paddle/fluid/framework/custom_operator.cc
paddle/fluid/framework/custom_operator.cc
+49
-35
paddle/fluid/framework/custom_operator.h
paddle/fluid/framework/custom_operator.h
+3
-2
paddle/fluid/framework/custom_raw_op_kernel_func.h
paddle/fluid/framework/custom_raw_op_kernel_func.h
+27
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+9
-0
python/paddle/fluid/tests/custom_op/CMakeLists.txt
python/paddle/fluid/tests/custom_op/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
...n/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
+52
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
...n/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
+21
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
...on/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
+84
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
...le/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
+50
-0
python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
...dle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
+97
-0
未找到文件。
paddle/fluid/framework/custom_operator.cc
浏览文件 @
5c3873f6
...
@@ -61,27 +61,27 @@ static T* DynLoad(void* handle, std::string name) {
...
@@ -61,27 +61,27 @@ static T* DynLoad(void* handle, std::string name) {
return
func
;
return
func
;
}
}
inline
bool
IsGradVar
(
const
std
::
string
&
var_name
)
{
inline
static
bool
IsGradVar
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kGradVarSuffix
;
std
::
string
suffix
=
kGradVarSuffix
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
}
}
inline
bool
IsDuplicableVar
(
const
std
::
string
&
var_name
)
{
inline
static
bool
IsDuplicableVar
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kTensorVectorSuffix
;
std
::
string
suffix
=
kTensorVectorSuffix
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
}
}
inline
std
::
string
NoGrad
(
const
std
::
string
&
var_name
)
{
inline
st
atic
st
d
::
string
NoGrad
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kGradVarSuffix
;
std
::
string
suffix
=
kGradVarSuffix
;
return
var_name
.
substr
(
0
,
var_name
.
size
()
-
kGradVarSuffixSize
);
return
var_name
.
substr
(
0
,
var_name
.
size
()
-
kGradVarSuffixSize
);
}
}
inline
bool
IsMemberOf
(
const
std
::
vector
<
std
::
string
>&
vec
,
inline
static
bool
IsMemberOf
(
const
std
::
vector
<
std
::
string
>&
vec
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
return
std
::
find
(
vec
.
cbegin
(),
vec
.
cend
(),
name
)
!=
vec
.
cend
();
return
std
::
find
(
vec
.
cbegin
(),
vec
.
cend
(),
name
)
!=
vec
.
cend
();
}
}
std
::
vector
<
std
::
string
>
ParseAttrStr
(
const
std
::
string
&
attr
)
{
st
atic
st
d
::
vector
<
std
::
string
>
ParseAttrStr
(
const
std
::
string
&
attr
)
{
auto
split_pos
=
attr
.
find_first_of
(
":"
);
auto
split_pos
=
attr
.
find_first_of
(
":"
);
PADDLE_ENFORCE_NE
(
split_pos
,
std
::
string
::
npos
,
PADDLE_ENFORCE_NE
(
split_pos
,
std
::
string
::
npos
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -602,44 +602,57 @@ class CustomGradOpMaker<imperative::OpBase>
...
@@ -602,44 +602,57 @@ class CustomGradOpMaker<imperative::OpBase>
//////////// Operator and Kernel Register //////////////
//////////// Operator and Kernel Register //////////////
void
RegisterOperatorKernelWithPlace
(
const
std
::
string
&
name
,
static
void
RegisterOperatorKernelWithPlace
(
const
paddle
::
KernelFunc
&
kernel_func
,
const
std
::
string
&
name
,
const
proto
::
VarType
::
Type
type
,
const
OperatorWithKernel
::
OpKernelFunc
&
op_kernel_func
,
const
PlaceType
&
place
,
const
proto
::
VarType
::
Type
type
,
const
PlaceType
&
place
)
{
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
std
::
vector
<
std
::
string
>&
attrs
)
{
OpKernelType
key
(
type
,
experimental
::
ConvertExtPlaceToInnerPlace
(
place
));
OpKernelType
key
(
type
,
experimental
::
ConvertExtPlaceToInnerPlace
(
place
));
VLOG
(
3
)
<<
"Custom Operator: op kernel key: "
<<
key
;
VLOG
(
3
)
<<
"Custom Operator: op kernel key: "
<<
key
;
OperatorWithKernel
::
AllOpKernels
()[
name
][
key
]
=
OperatorWithKernel
::
AllOpKernels
()[
name
][
key
]
=
op_kernel_func
;
[
kernel_func
,
inputs
,
outputs
,
attrs
](
const
framework
::
ExecutionContext
&
ctx
)
{
VLOG
(
3
)
<<
"Custom Operator: run custom kernel func in lambda."
;
RunKernelFunc
(
ctx
,
kernel_func
,
inputs
,
outputs
,
attrs
);
};
}
}
void
RegisterOperatorKernel
(
const
std
::
string
&
name
,
static
void
RegisterOperatorKernel
(
const
std
::
string
&
name
,
const
paddle
::
KernelFunc
&
kernel_func
,
const
paddle
::
KernelFunc
&
kernel_func
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
std
::
vector
<
std
::
string
>&
attrs
)
{
const
std
::
vector
<
std
::
string
>&
attrs
,
void
*
dso_handle
)
{
VLOG
(
3
)
<<
"Custom Operator: op name in kernel: "
<<
name
;
VLOG
(
3
)
<<
"Custom Operator: op name in kernel: "
<<
name
;
// NOTE [ Dummy Op Kernel Key ]
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based
// TODO(chenweihang): Because execute engine need get device context based
// op_kernel_key.place_, so we should register kernel for each
// op_kernel_key.place_, so we should register kernel for each
// device. But this is not entirely correct, if user only give a cpu kernel,
// device. But this is not entirely correct, if user only give a cpu kernel,
// but call api in gpu device, it will cause error.
// but call api in gpu device, it will cause error.
RegisterOperatorKernelWithPlace
(
name
,
kernel_func
,
proto
::
VarType
::
RAW
,
OperatorWithKernel
::
OpKernelFunc
op_kernel_func
;
PlaceType
::
kCPU
,
inputs
,
outputs
,
attrs
);
if
(
kernel_func
)
{
VLOG
(
3
)
<<
"Register custom operator "
<<
name
<<
" with kernel func"
;
op_kernel_func
=
[
kernel_func
,
inputs
,
outputs
,
attrs
](
const
framework
::
ExecutionContext
&
ctx
)
{
VLOG
(
3
)
<<
"Custom Operator: run custom kernel func in lambda."
;
RunKernelFunc
(
ctx
,
kernel_func
,
inputs
,
outputs
,
attrs
);
};
}
else
{
VLOG
(
3
)
<<
"Register custom operator "
<<
name
<<
" with raw op kernel func"
;
PADDLE_ENFORCE_NOT_NULL
(
dso_handle
,
platform
::
errors
::
InvalidArgument
(
"The dso handle must be provided if kernel_func is nullptr."
));
using
OpKernelFuncPtr
=
void
(
const
framework
::
ExecutionContext
&
);
auto
symbol_name
=
"PD_"
+
name
+
"_raw_op_kernel_func"
;
auto
*
func
=
detail
::
DynLoad
<
OpKernelFuncPtr
>
(
dso_handle
,
symbol_name
);
op_kernel_func
=
func
;
}
RegisterOperatorKernelWithPlace
(
name
,
op_kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kCPU
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
RegisterOperatorKernelWithPlace
(
name
,
kernel_func
,
proto
::
VarType
::
RAW
,
RegisterOperatorKernelWithPlace
(
name
,
op_
kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kGPU
,
inputs
,
outputs
,
attrs
);
PlaceType
::
kGPU
);
#endif
#endif
}
}
void
RegisterOperatorWithMetaInfo
(
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
,
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
)
{
void
*
dso_handle
)
{
/* Op register */
/* Op register */
OpInfo
info
;
OpInfo
info
;
...
@@ -792,7 +805,8 @@ void RegisterOperatorWithMetaInfo(
...
@@ -792,7 +805,8 @@ void RegisterOperatorWithMetaInfo(
}
}
// Kernel func
// Kernel func
RegisterOperatorKernel
(
op_name
,
kernel_fn
,
op_inputs
,
op_outputs
,
op_attrs
);
RegisterOperatorKernel
(
op_name
,
kernel_fn
,
op_inputs
,
op_outputs
,
op_attrs
,
dso_handle
);
// If grad op or double grad op exists
// If grad op or double grad op exists
std
::
string
cur_op_name
=
op_name
;
std
::
string
cur_op_name
=
op_name
;
...
@@ -900,7 +914,7 @@ void RegisterOperatorWithMetaInfo(
...
@@ -900,7 +914,7 @@ void RegisterOperatorWithMetaInfo(
// Kernel func
// Kernel func
RegisterOperatorKernel
(
grad_op_name
,
grad_kernel_fn
,
grad_op_inputs
,
RegisterOperatorKernel
(
grad_op_name
,
grad_kernel_fn
,
grad_op_inputs
,
grad_op_outputs
,
grad_op_attrs
);
grad_op_outputs
,
grad_op_attrs
,
dso_handle
);
// update current info
// update current info
OpInfoMap
::
Instance
().
Insert
(
cur_op_name
,
info
);
OpInfoMap
::
Instance
().
Insert
(
cur_op_name
,
info
);
...
@@ -912,14 +926,14 @@ void RegisterOperatorWithMetaInfo(
...
@@ -912,14 +926,14 @@ void RegisterOperatorWithMetaInfo(
}
}
void
RegisterOperatorWithMetaInfoMap
(
void
RegisterOperatorWithMetaInfoMap
(
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
)
{
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
,
void
*
dso_handle
)
{
auto
&
meta_info_map
=
op_meta_info_map
.
GetMap
();
auto
&
meta_info_map
=
op_meta_info_map
.
GetMap
();
VLOG
(
3
)
<<
"Custom Operator: size of op meta info map - "
VLOG
(
3
)
<<
"Custom Operator: size of op meta info map - "
<<
meta_info_map
.
size
();
<<
meta_info_map
.
size
();
// pair: {op_type, OpMetaInfo}
// pair: {op_type, OpMetaInfo}
for
(
auto
&
pair
:
meta_info_map
)
{
for
(
auto
&
pair
:
meta_info_map
)
{
VLOG
(
3
)
<<
"Custom Operator: pair first -> op name: "
<<
pair
.
first
;
VLOG
(
3
)
<<
"Custom Operator: pair first -> op name: "
<<
pair
.
first
;
RegisterOperatorWithMetaInfo
(
pair
.
second
);
RegisterOperatorWithMetaInfo
(
pair
.
second
,
dso_handle
);
}
}
}
}
...
@@ -934,7 +948,7 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
...
@@ -934,7 +948,7 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
detail
::
DynLoad
<
get_op_meta_info_map_t
>
(
handle
,
"PD_GetOpMetaInfoMap"
);
detail
::
DynLoad
<
get_op_meta_info_map_t
>
(
handle
,
"PD_GetOpMetaInfoMap"
);
auto
&
op_meta_info_map
=
get_op_meta_info_map
();
auto
&
op_meta_info_map
=
get_op_meta_info_map
();
RegisterOperatorWithMetaInfoMap
(
op_meta_info_map
);
RegisterOperatorWithMetaInfoMap
(
op_meta_info_map
,
handle
);
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/custom_operator.h
浏览文件 @
5c3873f6
...
@@ -26,10 +26,11 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
...
@@ -26,10 +26,11 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
// Register custom op api: register op directly
// Register custom op api: register op directly
void
RegisterOperatorWithMetaInfoMap
(
void
RegisterOperatorWithMetaInfoMap
(
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
);
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
,
void
*
dso_handle
=
nullptr
);
// Interface for selective register custom op.
// Interface for selective register custom op.
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
);
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
,
void
*
dso_handle
=
nullptr
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/custom_raw_op_kernel_func.h
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/api/ext/op_meta_info.h"
// NOTE(zengjinle): this macro is only for internal usage. Commonly, users
// should not use this macro.
#define __PD_DEFINE_RAW_OP_KERNEL_FUNC(op_name, ctx) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_raw_op_kernel_func__##op_name, \
"__PD_DEFINE_RAW_KERNEL_FUNC must be called in global namespace."); \
extern "C" void PD_##op_name##_raw_op_kernel_func( \
const ::paddle::framework::ExecutionContext& ctx)
paddle/fluid/pybind/pybind.cc
浏览文件 @
5c3873f6
...
@@ -185,6 +185,14 @@ bool IsCompiledWithCUDA() {
...
@@ -185,6 +185,14 @@ bool IsCompiledWithCUDA() {
#endif
#endif
}
}
bool
IsCompiledWithNCCL
()
{
#ifdef PADDLE_WITH_NCCL
return
true
;
#else
return
false
;
#endif
}
bool
IsCompiledWithROCM
()
{
bool
IsCompiledWithROCM
()
{
#ifndef PADDLE_WITH_HIP
#ifndef PADDLE_WITH_HIP
return
false
;
return
false
;
...
@@ -2433,6 +2441,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2433,6 +2441,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"is_compiled_with_ipu"
,
IsCompiledWithIPU
);
m
.
def
(
"is_compiled_with_ipu"
,
IsCompiledWithIPU
);
m
.
def
(
"is_compiled_with_xpu"
,
IsCompiledWithXPU
);
m
.
def
(
"is_compiled_with_xpu"
,
IsCompiledWithXPU
);
m
.
def
(
"is_compiled_with_mkldnn"
,
IsCompiledWithMKLDNN
);
m
.
def
(
"is_compiled_with_mkldnn"
,
IsCompiledWithMKLDNN
);
m
.
def
(
"is_compiled_with_nccl"
,
IsCompiledWithNCCL
);
m
.
def
(
"is_compiled_with_cinn"
,
IsCompiledWithCINN
);
m
.
def
(
"is_compiled_with_cinn"
,
IsCompiledWithCINN
);
m
.
def
(
"is_compiled_with_mlu"
,
IsCompiledWithMLU
);
m
.
def
(
"is_compiled_with_mlu"
,
IsCompiledWithMLU
);
m
.
def
(
"_is_compiled_with_heterps"
,
IsCompiledWithHETERPS
);
m
.
def
(
"_is_compiled_with_heterps"
,
IsCompiledWithHETERPS
);
...
...
python/paddle/fluid/tests/custom_op/CMakeLists.txt
浏览文件 @
5c3873f6
...
@@ -10,6 +10,9 @@ if(WITH_GPU OR APPLE)
...
@@ -10,6 +10,9 @@ if(WITH_GPU OR APPLE)
set_tests_properties
(
test_custom_relu_model PROPERTIES TIMEOUT 180
)
set_tests_properties
(
test_custom_relu_model PROPERTIES TIMEOUT 180
)
endif
()
endif
()
py_test
(
test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py
)
set_tests_properties
(
test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180
)
# CPU custom op tests: only compile .cc file
# CPU custom op tests: only compile .cc file
py_test
(
test_dispatch_jit SRCS test_dispatch_jit.py
)
py_test
(
test_dispatch_jit SRCS test_dispatch_jit.py
)
py_test
(
test_multi_out_jit SRCS test_multi_out_jit.py
)
py_test
(
test_multi_out_jit SRCS test_multi_out_jit.py
)
...
...
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "custom_raw_op_kernel_op.h" // NOLINT
#include "paddle/fluid/framework/custom_raw_op_kernel_func.h"
#include "paddle/fluid/platform/enforce.h"
void
ReluCPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluForward
(
x
,
y
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
);
#else
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unimplemented
(
"ReluGPUForward is not supported when not compiled with GPU."
));
}
#endif
__PD_DEFINE_RAW_OP_KERNEL_FUNC
(
custom_raw_relu
,
ctx
)
{
namespace
f
=
paddle
::
framework
;
const
auto
*
x
=
ctx
.
Input
<
f
::
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
f
::
Tensor
>
(
"Y"
);
PADDLE_ENFORCE_NOT_NULL
(
x
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Input(X) should not be nullptr."
));
PADDLE_ENFORCE_NOT_NULL
(
y
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Input(X) should not be nullptr."
));
if
(
paddle
::
platform
::
is_gpu_place
(
x
->
place
()))
{
ReluGPUForward
(
*
x
,
y
);
}
else
{
ReluCPUForward
(
*
x
,
y
);
}
}
PD_BUILD_OP
(
custom_raw_relu
).
Inputs
({
"X"
}).
Outputs
({
"Y"
});
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "custom_raw_op_kernel_op.h" // NOLINT
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluForward
(
x
,
y
);
}
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace
custom_raw_op
{
struct
ReluFunctor
{
explicit
ReluFunctor
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
:
x_
(
x
),
y_
(
y
)
{}
template
<
typename
U
>
struct
Impl
{
Impl
(
const
U
*
x
,
U
*
y
)
:
x_
(
x
),
y_
(
y
)
{}
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
y_
[
i
]
=
(
x_
[
i
]
>
static_cast
<
U
>
(
0
)
?
x_
[
i
]
:
static_cast
<
U
>
(
0
));
}
private:
const
U
*
x_
;
U
*
y_
;
};
template
<
typename
T
>
void
apply
()
{
auto
n
=
x_
.
numel
();
auto
place
=
x_
.
place
();
const
auto
*
x_data
=
x_
.
data
<
T
>
();
y_
->
Resize
(
x_
.
dims
());
auto
*
y_data
=
y_
->
mutable_data
<
T
>
(
place
);
const
auto
&
dev_ctx
=
*
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
#define LAUNCH_RELU_KERNEL(DevCtxT) \
do { \
auto &__dev_ctx = dynamic_cast<const DevCtxT &>(dev_ctx); \
paddle::platform::ForRange<DevCtxT> for_range(__dev_ctx, n); \
Impl<T> functor(x_data, y_data); \
for_range(functor); \
} while (0)
#if defined(__NVCC__) || defined(__HIPCC__)
if
(
paddle
::
platform
::
is_gpu_place
(
place
))
{
LAUNCH_RELU_KERNEL
(
paddle
::
platform
::
CUDADeviceContext
);
return
;
}
#endif
LAUNCH_RELU_KERNEL
(
paddle
::
platform
::
CPUDeviceContext
);
#undef LAUNCH_RELU_KERNEL
}
private:
const
paddle
::
framework
::
Tensor
&
x_
;
paddle
::
framework
::
Tensor
*
y_
;
};
inline
void
ReluForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluFunctor
functor
(
x
,
y
);
paddle
::
framework
::
VisitDataType
(
x
.
type
(),
functor
);
}
}
// namespace custom_raw_op
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
0 → 100644
浏览文件 @
5c3873f6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.fluid.core
as
core
from
paddle.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
setup
from
utils
import
paddle_includes
,
extra_compile_args
if
paddle
.
is_compiled_with_cuda
():
sources
=
[
'custom_raw_op_kernel_op.cc'
,
'custom_raw_op_kernel_op.cu'
]
extension
=
CUDAExtension
else
:
sources
=
[
'custom_raw_op_kernel_op.cc'
]
extension
=
CppExtension
cwd
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
os
.
chdir
(
cwd
)
if
os
.
name
==
'nt'
:
compile_dir
=
os
.
path
.
join
(
os
.
environ
[
'work_dir'
],
os
.
environ
[
'BUILD_DIR'
])
else
:
compile_dir
=
os
.
path
.
join
(
os
.
environ
[
'PADDLE_ROOT'
],
'build'
)
macros
=
[]
if
core
.
is_compiled_with_mkldnn
():
macros
.
append
((
"PADDLE_WITH_MKLDNN"
,
None
))
if
core
.
is_compiled_with_nccl
():
macros
.
append
((
"PADDLE_WITH_NCCL"
,
None
))
include_dirs
=
list
(
paddle_includes
)
+
[
cwd
]
setup
(
name
=
os
.
getenv
(
"MODULE_NAME"
,
"custom_raw_op_kernel_op_setup"
),
ext_modules
=
extension
(
sources
=
sources
,
include_dirs
=
include_dirs
,
extra_compile_args
=
extra_compile_args
,
_compile_dir
=
compile_dir
,
define_macros
=
macros
))
python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
0 → 100644
浏览文件 @
5c3873f6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
shlex
import
site
import
sys
import
importlib
import
unittest
import
numpy
as
np
MODULE_NAME
=
"custom_raw_op_kernel_op_lib"
def
prepare_module_path
():
# NOTE(Aurelius84): Normally, it's no need to add following codes for users.
# But we simulate to pip install in current process, so interpreter don't snap
# sys.path has been updated. So we update it manually.
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
if
os
.
name
==
'nt'
:
# NOTE(zhouwei25): getsitepackages on windows will return a list: [python install dir, site packages dir]
site_dir
=
site
.
getsitepackages
()[
1
]
else
:
site_dir
=
site
.
getsitepackages
()[
0
]
custom_egg_path
=
[
x
for
x
in
os
.
listdir
(
site_dir
)
if
MODULE_NAME
in
x
]
assert
len
(
custom_egg_path
)
==
1
,
"Matched egg number is %d."
%
len
(
custom_egg_path
)
sys
.
path
.
append
(
os
.
path
.
join
(
site_dir
,
custom_egg_path
[
0
]))
# FIXME(zengjinle): do not know how to get the _compile_dir argument
# on Windows CI when compiling the custom op. Skip it on Windows CI
# temporarily.
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
"Windows does not support yet."
)
class
TestCustomRawReluOp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
path
=
os
.
path
.
join
(
path
,
"custom_raw_op_kernel_op_setup.py"
)
cmd
=
[
sys
.
executable
,
path
,
"install"
,
"--force"
]
cmd
=
" "
.
join
([
shlex
.
quote
(
c
)
for
c
in
cmd
])
os
.
environ
[
'MODULE_NAME'
]
=
MODULE_NAME
assert
os
.
system
(
cmd
)
==
0
prepare_module_path
()
@
classmethod
def
tearDownClass
(
cls
):
cmd
=
[
sys
.
executable
,
"-m"
,
"pip"
,
"uninstall"
,
"-y"
,
MODULE_NAME
]
cmd
=
" "
.
join
([
shlex
.
quote
(
c
)
for
c
in
cmd
])
assert
os
.
system
(
cmd
)
==
0
def
custom_raw_relu
(
self
,
x
):
module
=
importlib
.
import_module
(
MODULE_NAME
)
custom_raw_relu_op
=
getattr
(
module
,
"custom_raw_relu"
)
self
.
assertTrue
(
custom_raw_relu_op
is
not
None
)
return
custom_raw_relu_op
(
x
)
def
test_dygraph
(
self
):
x
=
paddle
.
to_tensor
(
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
2
,
3
]))
y1
=
self
.
custom_raw_relu
(
x
)
y2
=
paddle
.
nn
.
ReLU
()(
x
)
self
.
assertTrue
(
np
.
array_equal
(
y1
.
numpy
(),
y2
.
numpy
()))
def
test_static
(
self
):
paddle
.
enable_static
()
shape
=
[
2
,
3
]
x
=
paddle
.
static
.
data
(
name
=
"x"
,
dtype
=
'float32'
,
shape
=
shape
)
y1
=
self
.
custom_raw_relu
(
x
)
y2
=
paddle
.
nn
.
ReLU
()(
x
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
paddle
.
static
.
default_startup_program
())
x_np
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
2
,
3
]).
astype
(
'float32'
)
y1_value
,
y2_value
=
exe
.
run
(
paddle
.
static
.
default_main_program
(),
feed
=
{
x
.
name
:
x_np
},
fetch_list
=
[
y1
,
y2
])
self
.
assertTrue
(
np
.
array_equal
(
y1_value
,
y2_value
))
paddle
.
disable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录