Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5c3873f6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c3873f6
编写于
2月 08, 2022
作者:
S
sneaxiy
提交者:
GitHub
2月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add __PD_DEFINE_RAW_OP_KERNEL_FUNC for registering custom op kernel with ExecutionContext (#39352)
* hack custom op * add ut * skip windows ci
上级
fee4316d
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
395 addition
and
37 deletion
+395
-37
paddle/fluid/framework/custom_operator.cc
paddle/fluid/framework/custom_operator.cc
+49
-35
paddle/fluid/framework/custom_operator.h
paddle/fluid/framework/custom_operator.h
+3
-2
paddle/fluid/framework/custom_raw_op_kernel_func.h
paddle/fluid/framework/custom_raw_op_kernel_func.h
+27
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+9
-0
python/paddle/fluid/tests/custom_op/CMakeLists.txt
python/paddle/fluid/tests/custom_op/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
...n/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
+52
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
...n/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
+21
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
...on/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
+84
-0
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
...le/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
+50
-0
python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
...dle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
+97
-0
未找到文件。
paddle/fluid/framework/custom_operator.cc
浏览文件 @
5c3873f6
...
...
@@ -61,27 +61,27 @@ static T* DynLoad(void* handle, std::string name) {
return
func
;
}
inline
bool
IsGradVar
(
const
std
::
string
&
var_name
)
{
inline
static
bool
IsGradVar
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kGradVarSuffix
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
}
inline
bool
IsDuplicableVar
(
const
std
::
string
&
var_name
)
{
inline
static
bool
IsDuplicableVar
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kTensorVectorSuffix
;
return
var_name
.
rfind
(
suffix
)
!=
std
::
string
::
npos
;
}
inline
std
::
string
NoGrad
(
const
std
::
string
&
var_name
)
{
inline
st
atic
st
d
::
string
NoGrad
(
const
std
::
string
&
var_name
)
{
std
::
string
suffix
=
kGradVarSuffix
;
return
var_name
.
substr
(
0
,
var_name
.
size
()
-
kGradVarSuffixSize
);
}
inline
bool
IsMemberOf
(
const
std
::
vector
<
std
::
string
>&
vec
,
inline
static
bool
IsMemberOf
(
const
std
::
vector
<
std
::
string
>&
vec
,
const
std
::
string
&
name
)
{
return
std
::
find
(
vec
.
cbegin
(),
vec
.
cend
(),
name
)
!=
vec
.
cend
();
}
std
::
vector
<
std
::
string
>
ParseAttrStr
(
const
std
::
string
&
attr
)
{
st
atic
st
d
::
vector
<
std
::
string
>
ParseAttrStr
(
const
std
::
string
&
attr
)
{
auto
split_pos
=
attr
.
find_first_of
(
":"
);
PADDLE_ENFORCE_NE
(
split_pos
,
std
::
string
::
npos
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -602,44 +602,57 @@ class CustomGradOpMaker<imperative::OpBase>
//////////// Operator and Kernel Register //////////////
void
RegisterOperatorKernelWithPlace
(
const
std
::
string
&
name
,
const
paddle
::
KernelFunc
&
kernel_func
,
const
proto
::
VarType
::
Type
type
,
const
PlaceType
&
place
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
std
::
vector
<
std
::
string
>&
attrs
)
{
static
void
RegisterOperatorKernelWithPlace
(
const
std
::
string
&
name
,
const
OperatorWithKernel
::
OpKernelFunc
&
op_kernel_func
,
const
proto
::
VarType
::
Type
type
,
const
PlaceType
&
place
)
{
OpKernelType
key
(
type
,
experimental
::
ConvertExtPlaceToInnerPlace
(
place
));
VLOG
(
3
)
<<
"Custom Operator: op kernel key: "
<<
key
;
OperatorWithKernel
::
AllOpKernels
()[
name
][
key
]
=
[
kernel_func
,
inputs
,
outputs
,
attrs
](
const
framework
::
ExecutionContext
&
ctx
)
{
VLOG
(
3
)
<<
"Custom Operator: run custom kernel func in lambda."
;
RunKernelFunc
(
ctx
,
kernel_func
,
inputs
,
outputs
,
attrs
);
};
OperatorWithKernel
::
AllOpKernels
()[
name
][
key
]
=
op_kernel_func
;
}
void
RegisterOperatorKernel
(
const
std
::
string
&
name
,
static
void
RegisterOperatorKernel
(
const
std
::
string
&
name
,
const
paddle
::
KernelFunc
&
kernel_func
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
std
::
vector
<
std
::
string
>&
attrs
)
{
const
std
::
vector
<
std
::
string
>&
attrs
,
void
*
dso_handle
)
{
VLOG
(
3
)
<<
"Custom Operator: op name in kernel: "
<<
name
;
// NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based
// op_kernel_key.place_, so we should register kernel for each
// device. But this is not entirely correct, if user only give a cpu kernel,
// but call api in gpu device, it will cause error.
RegisterOperatorKernelWithPlace
(
name
,
kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kCPU
,
inputs
,
outputs
,
attrs
);
OperatorWithKernel
::
OpKernelFunc
op_kernel_func
;
if
(
kernel_func
)
{
VLOG
(
3
)
<<
"Register custom operator "
<<
name
<<
" with kernel func"
;
op_kernel_func
=
[
kernel_func
,
inputs
,
outputs
,
attrs
](
const
framework
::
ExecutionContext
&
ctx
)
{
VLOG
(
3
)
<<
"Custom Operator: run custom kernel func in lambda."
;
RunKernelFunc
(
ctx
,
kernel_func
,
inputs
,
outputs
,
attrs
);
};
}
else
{
VLOG
(
3
)
<<
"Register custom operator "
<<
name
<<
" with raw op kernel func"
;
PADDLE_ENFORCE_NOT_NULL
(
dso_handle
,
platform
::
errors
::
InvalidArgument
(
"The dso handle must be provided if kernel_func is nullptr."
));
using
OpKernelFuncPtr
=
void
(
const
framework
::
ExecutionContext
&
);
auto
symbol_name
=
"PD_"
+
name
+
"_raw_op_kernel_func"
;
auto
*
func
=
detail
::
DynLoad
<
OpKernelFuncPtr
>
(
dso_handle
,
symbol_name
);
op_kernel_func
=
func
;
}
RegisterOperatorKernelWithPlace
(
name
,
op_kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kCPU
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
RegisterOperatorKernelWithPlace
(
name
,
kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kGPU
,
inputs
,
outputs
,
attrs
);
RegisterOperatorKernelWithPlace
(
name
,
op_
kernel_func
,
proto
::
VarType
::
RAW
,
PlaceType
::
kGPU
);
#endif
}
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
)
{
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
,
void
*
dso_handle
)
{
/* Op register */
OpInfo
info
;
...
...
@@ -792,7 +805,8 @@ void RegisterOperatorWithMetaInfo(
}
// Kernel func
RegisterOperatorKernel
(
op_name
,
kernel_fn
,
op_inputs
,
op_outputs
,
op_attrs
);
RegisterOperatorKernel
(
op_name
,
kernel_fn
,
op_inputs
,
op_outputs
,
op_attrs
,
dso_handle
);
// If grad op or double grad op exists
std
::
string
cur_op_name
=
op_name
;
...
...
@@ -900,7 +914,7 @@ void RegisterOperatorWithMetaInfo(
// Kernel func
RegisterOperatorKernel
(
grad_op_name
,
grad_kernel_fn
,
grad_op_inputs
,
grad_op_outputs
,
grad_op_attrs
);
grad_op_outputs
,
grad_op_attrs
,
dso_handle
);
// update current info
OpInfoMap
::
Instance
().
Insert
(
cur_op_name
,
info
);
...
...
@@ -912,14 +926,14 @@ void RegisterOperatorWithMetaInfo(
}
void
RegisterOperatorWithMetaInfoMap
(
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
)
{
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
,
void
*
dso_handle
)
{
auto
&
meta_info_map
=
op_meta_info_map
.
GetMap
();
VLOG
(
3
)
<<
"Custom Operator: size of op meta info map - "
<<
meta_info_map
.
size
();
// pair: {op_type, OpMetaInfo}
for
(
auto
&
pair
:
meta_info_map
)
{
VLOG
(
3
)
<<
"Custom Operator: pair first -> op name: "
<<
pair
.
first
;
RegisterOperatorWithMetaInfo
(
pair
.
second
);
RegisterOperatorWithMetaInfo
(
pair
.
second
,
dso_handle
);
}
}
...
...
@@ -934,7 +948,7 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
detail
::
DynLoad
<
get_op_meta_info_map_t
>
(
handle
,
"PD_GetOpMetaInfoMap"
);
auto
&
op_meta_info_map
=
get_op_meta_info_map
();
RegisterOperatorWithMetaInfoMap
(
op_meta_info_map
);
RegisterOperatorWithMetaInfoMap
(
op_meta_info_map
,
handle
);
}
}
// namespace framework
...
...
paddle/fluid/framework/custom_operator.h
浏览文件 @
5c3873f6
...
...
@@ -26,10 +26,11 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
// Register custom op api: register op directly
void
RegisterOperatorWithMetaInfoMap
(
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
);
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
,
void
*
dso_handle
=
nullptr
);
// Interface for selective register custom op.
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
);
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
,
void
*
dso_handle
=
nullptr
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/custom_raw_op_kernel_func.h
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/api/ext/op_meta_info.h"
// NOTE(zengjinle): this macro is only for internal usage. Commonly, users
// should not use this macro.
#define __PD_DEFINE_RAW_OP_KERNEL_FUNC(op_name, ctx) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_raw_op_kernel_func__##op_name, \
"__PD_DEFINE_RAW_KERNEL_FUNC must be called in global namespace."); \
extern "C" void PD_##op_name##_raw_op_kernel_func( \
const ::paddle::framework::ExecutionContext& ctx)
paddle/fluid/pybind/pybind.cc
浏览文件 @
5c3873f6
...
...
@@ -185,6 +185,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool
IsCompiledWithNCCL
()
{
#ifdef PADDLE_WITH_NCCL
return
true
;
#else
return
false
;
#endif
}
bool
IsCompiledWithROCM
()
{
#ifndef PADDLE_WITH_HIP
return
false
;
...
...
@@ -2433,6 +2441,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"is_compiled_with_ipu"
,
IsCompiledWithIPU
);
m
.
def
(
"is_compiled_with_xpu"
,
IsCompiledWithXPU
);
m
.
def
(
"is_compiled_with_mkldnn"
,
IsCompiledWithMKLDNN
);
m
.
def
(
"is_compiled_with_nccl"
,
IsCompiledWithNCCL
);
m
.
def
(
"is_compiled_with_cinn"
,
IsCompiledWithCINN
);
m
.
def
(
"is_compiled_with_mlu"
,
IsCompiledWithMLU
);
m
.
def
(
"_is_compiled_with_heterps"
,
IsCompiledWithHETERPS
);
...
...
python/paddle/fluid/tests/custom_op/CMakeLists.txt
浏览文件 @
5c3873f6
...
...
@@ -10,6 +10,9 @@ if(WITH_GPU OR APPLE)
set_tests_properties
(
test_custom_relu_model PROPERTIES TIMEOUT 180
)
endif
()
py_test
(
test_custom_raw_op_kernel_op SRCS test_custom_raw_op_kernel_op.py
)
set_tests_properties
(
test_custom_raw_op_kernel_op PROPERTIES TIMEOUT 180
)
# CPU custom op tests: only compile .cc file
py_test
(
test_dispatch_jit SRCS test_dispatch_jit.py
)
py_test
(
test_multi_out_jit SRCS test_multi_out_jit.py
)
...
...
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cc
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "custom_raw_op_kernel_op.h" // NOLINT
#include "paddle/fluid/framework/custom_raw_op_kernel_func.h"
#include "paddle/fluid/platform/enforce.h"
void
ReluCPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluForward
(
x
,
y
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
);
#else
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unimplemented
(
"ReluGPUForward is not supported when not compiled with GPU."
));
}
#endif
__PD_DEFINE_RAW_OP_KERNEL_FUNC
(
custom_raw_relu
,
ctx
)
{
namespace
f
=
paddle
::
framework
;
const
auto
*
x
=
ctx
.
Input
<
f
::
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
f
::
Tensor
>
(
"Y"
);
PADDLE_ENFORCE_NOT_NULL
(
x
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Input(X) should not be nullptr."
));
PADDLE_ENFORCE_NOT_NULL
(
y
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"Input(X) should not be nullptr."
));
if
(
paddle
::
platform
::
is_gpu_place
(
x
->
place
()))
{
ReluGPUForward
(
*
x
,
y
);
}
else
{
ReluCPUForward
(
*
x
,
y
);
}
}
PD_BUILD_OP
(
custom_raw_relu
).
Inputs
({
"X"
}).
Outputs
({
"Y"
});
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.cu
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "custom_raw_op_kernel_op.h" // NOLINT
void
ReluGPUForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluForward
(
x
,
y
);
}
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op.h
0 → 100644
浏览文件 @
5c3873f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
namespace
custom_raw_op
{
struct
ReluFunctor
{
explicit
ReluFunctor
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
:
x_
(
x
),
y_
(
y
)
{}
template
<
typename
U
>
struct
Impl
{
Impl
(
const
U
*
x
,
U
*
y
)
:
x_
(
x
),
y_
(
y
)
{}
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
y_
[
i
]
=
(
x_
[
i
]
>
static_cast
<
U
>
(
0
)
?
x_
[
i
]
:
static_cast
<
U
>
(
0
));
}
private:
const
U
*
x_
;
U
*
y_
;
};
template
<
typename
T
>
void
apply
()
{
auto
n
=
x_
.
numel
();
auto
place
=
x_
.
place
();
const
auto
*
x_data
=
x_
.
data
<
T
>
();
y_
->
Resize
(
x_
.
dims
());
auto
*
y_data
=
y_
->
mutable_data
<
T
>
(
place
);
const
auto
&
dev_ctx
=
*
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
#define LAUNCH_RELU_KERNEL(DevCtxT) \
do { \
auto &__dev_ctx = dynamic_cast<const DevCtxT &>(dev_ctx); \
paddle::platform::ForRange<DevCtxT> for_range(__dev_ctx, n); \
Impl<T> functor(x_data, y_data); \
for_range(functor); \
} while (0)
#if defined(__NVCC__) || defined(__HIPCC__)
if
(
paddle
::
platform
::
is_gpu_place
(
place
))
{
LAUNCH_RELU_KERNEL
(
paddle
::
platform
::
CUDADeviceContext
);
return
;
}
#endif
LAUNCH_RELU_KERNEL
(
paddle
::
platform
::
CPUDeviceContext
);
#undef LAUNCH_RELU_KERNEL
}
private:
const
paddle
::
framework
::
Tensor
&
x_
;
paddle
::
framework
::
Tensor
*
y_
;
};
inline
void
ReluForward
(
const
paddle
::
framework
::
Tensor
&
x
,
paddle
::
framework
::
Tensor
*
y
)
{
custom_raw_op
::
ReluFunctor
functor
(
x
,
y
);
paddle
::
framework
::
VisitDataType
(
x
.
type
(),
functor
);
}
}
// namespace custom_raw_op
python/paddle/fluid/tests/custom_op/custom_raw_op_kernel_op_setup.py
0 → 100644
浏览文件 @
5c3873f6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.fluid.core
as
core
from
paddle.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
setup
from
utils
import
paddle_includes
,
extra_compile_args
if
paddle
.
is_compiled_with_cuda
():
sources
=
[
'custom_raw_op_kernel_op.cc'
,
'custom_raw_op_kernel_op.cu'
]
extension
=
CUDAExtension
else
:
sources
=
[
'custom_raw_op_kernel_op.cc'
]
extension
=
CppExtension
cwd
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
os
.
chdir
(
cwd
)
if
os
.
name
==
'nt'
:
compile_dir
=
os
.
path
.
join
(
os
.
environ
[
'work_dir'
],
os
.
environ
[
'BUILD_DIR'
])
else
:
compile_dir
=
os
.
path
.
join
(
os
.
environ
[
'PADDLE_ROOT'
],
'build'
)
macros
=
[]
if
core
.
is_compiled_with_mkldnn
():
macros
.
append
((
"PADDLE_WITH_MKLDNN"
,
None
))
if
core
.
is_compiled_with_nccl
():
macros
.
append
((
"PADDLE_WITH_NCCL"
,
None
))
include_dirs
=
list
(
paddle_includes
)
+
[
cwd
]
setup
(
name
=
os
.
getenv
(
"MODULE_NAME"
,
"custom_raw_op_kernel_op_setup"
),
ext_modules
=
extension
(
sources
=
sources
,
include_dirs
=
include_dirs
,
extra_compile_args
=
extra_compile_args
,
_compile_dir
=
compile_dir
,
define_macros
=
macros
))
python/paddle/fluid/tests/custom_op/test_custom_raw_op_kernel_op.py
0 → 100644
浏览文件 @
5c3873f6
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
shlex
import
site
import
sys
import
importlib
import
unittest
import
numpy
as
np
MODULE_NAME
=
"custom_raw_op_kernel_op_lib"
def
prepare_module_path
():
# NOTE(Aurelius84): Normally, it's no need to add following codes for users.
# But we simulate to pip install in current process, so interpreter don't snap
# sys.path has been updated. So we update it manually.
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
if
os
.
name
==
'nt'
:
# NOTE(zhouwei25): getsitepackages on windows will return a list: [python install dir, site packages dir]
site_dir
=
site
.
getsitepackages
()[
1
]
else
:
site_dir
=
site
.
getsitepackages
()[
0
]
custom_egg_path
=
[
x
for
x
in
os
.
listdir
(
site_dir
)
if
MODULE_NAME
in
x
]
assert
len
(
custom_egg_path
)
==
1
,
"Matched egg number is %d."
%
len
(
custom_egg_path
)
sys
.
path
.
append
(
os
.
path
.
join
(
site_dir
,
custom_egg_path
[
0
]))
# FIXME(zengjinle): do not know how to get the _compile_dir argument
# on Windows CI when compiling the custom op. Skip it on Windows CI
# temporarily.
@
unittest
.
skipIf
(
os
.
name
==
"nt"
,
"Windows does not support yet."
)
class
TestCustomRawReluOp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
path
=
os
.
path
.
join
(
path
,
"custom_raw_op_kernel_op_setup.py"
)
cmd
=
[
sys
.
executable
,
path
,
"install"
,
"--force"
]
cmd
=
" "
.
join
([
shlex
.
quote
(
c
)
for
c
in
cmd
])
os
.
environ
[
'MODULE_NAME'
]
=
MODULE_NAME
assert
os
.
system
(
cmd
)
==
0
prepare_module_path
()
@
classmethod
def
tearDownClass
(
cls
):
cmd
=
[
sys
.
executable
,
"-m"
,
"pip"
,
"uninstall"
,
"-y"
,
MODULE_NAME
]
cmd
=
" "
.
join
([
shlex
.
quote
(
c
)
for
c
in
cmd
])
assert
os
.
system
(
cmd
)
==
0
def
custom_raw_relu
(
self
,
x
):
module
=
importlib
.
import_module
(
MODULE_NAME
)
custom_raw_relu_op
=
getattr
(
module
,
"custom_raw_relu"
)
self
.
assertTrue
(
custom_raw_relu_op
is
not
None
)
return
custom_raw_relu_op
(
x
)
def
test_dygraph
(
self
):
x
=
paddle
.
to_tensor
(
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
2
,
3
]))
y1
=
self
.
custom_raw_relu
(
x
)
y2
=
paddle
.
nn
.
ReLU
()(
x
)
self
.
assertTrue
(
np
.
array_equal
(
y1
.
numpy
(),
y2
.
numpy
()))
def
test_static
(
self
):
paddle
.
enable_static
()
shape
=
[
2
,
3
]
x
=
paddle
.
static
.
data
(
name
=
"x"
,
dtype
=
'float32'
,
shape
=
shape
)
y1
=
self
.
custom_raw_relu
(
x
)
y2
=
paddle
.
nn
.
ReLU
()(
x
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
paddle
.
static
.
default_startup_program
())
x_np
=
np
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
[
2
,
3
]).
astype
(
'float32'
)
y1_value
,
y2_value
=
exe
.
run
(
paddle
.
static
.
default_main_program
(),
feed
=
{
x
.
name
:
x_np
},
fetch_list
=
[
y1
,
y2
])
self
.
assertTrue
(
np
.
array_equal
(
y1_value
,
y2_value
))
paddle
.
disable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录