Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2135020a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
2135020a
编写于
2月 21, 2023
作者:
H
HongyuJia
提交者:
GitHub
2月 21, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cpp Extension] Add unittest, mixed calling of op and extension (#50678)
* testcase init commit * skip sign conflict * fix year
上级
397c9403
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
424 addition
and
17 deletion
+424
-17
python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension.cc
...addle/fluid/tests/cpp_extension/mix_relu_and_extension.cc
+163
-0
python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py
...fluid/tests/cpp_extension/mix_relu_and_extension_setup.py
+30
-0
python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py
...dle/fluid/tests/cpp_extension/test_cpp_extension_setup.py
+192
-17
python/paddle/fluid/tests/cpp_extension/utils.py
python/paddle/fluid/tests/cpp_extension/utils.py
+39
-0
未找到文件。
python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension.cc
0 → 100644
浏览文件 @
2135020a
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "custom_power.h" // NOLINT
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
template
<
typename
data_t
>
void
relu_cpu_forward_kernel
(
const
data_t
*
x_data
,
data_t
*
out_data
,
int64_t
x_numel
)
{
PD_CHECK
(
x_data
!=
nullptr
,
"x_data is nullptr."
);
PD_CHECK
(
out_data
!=
nullptr
,
"out_data is nullptr."
);
for
(
int64_t
i
=
0
;
i
<
x_numel
;
++
i
)
{
out_data
[
i
]
=
std
::
max
(
static_cast
<
data_t
>
(
0.
),
x_data
[
i
]);
}
}
template
<
typename
data_t
>
void
relu_cpu_backward_kernel
(
const
data_t
*
grad_out_data
,
const
data_t
*
out_data
,
data_t
*
grad_x_data
,
int64_t
out_numel
)
{
for
(
int64_t
i
=
0
;
i
<
out_numel
;
++
i
)
{
grad_x_data
[
i
]
=
grad_out_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
}
template
<
typename
data_t
>
void
relu_cpu_double_backward_kernel
(
const
data_t
*
out_data
,
const
data_t
*
ddx_data
,
data_t
*
ddout_data
,
int64_t
ddout_numel
)
{
for
(
int64_t
i
=
0
;
i
<
ddout_numel
;
++
i
)
{
ddout_data
[
i
]
=
ddx_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_forward
(
const
paddle
::
Tensor
&
x
)
{
CHECK_CPU_INPUT
(
x
);
auto
out
=
paddle
::
empty_like
(
x
);
PD_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"relu_cpu_forward"
,
([
&
]
{
relu_cpu_forward_kernel
<
data_t
>
(
x
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(),
x
.
numel
());
}));
return
{
out
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_backward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
empty_like
(
x
);
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_backward"
,
([
&
]
{
relu_cpu_backward_kernel
<
data_t
>
(
grad_out
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(),
grad_x
.
data
<
data_t
>
(),
out
.
size
());
}));
return
{
grad_x
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
CHECK_CPU_INPUT
(
out
);
CHECK_CPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
empty
(
out
.
shape
(),
out
.
dtype
(),
out
.
place
());
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_double_backward"
,
([
&
]
{
relu_cpu_double_backward_kernel
<
data_t
>
(
out
.
data
<
data_t
>
(),
ddx
.
data
<
data_t
>
(),
ddout
.
mutable_data
<
data_t
>
(
out
.
place
()),
ddout
.
size
());
}));
return
{
ddout
};
}
std
::
vector
<
paddle
::
Tensor
>
ReluForward
(
const
paddle
::
Tensor
&
x
)
{
if
(
x
.
is_cpu
())
{
return
relu_cpu_forward
(
x
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
paddle
::
Tensor
>
ReluBackward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
if
(
x
.
is_cpu
())
{
return
relu_cpu_backward
(
x
,
out
,
grad_out
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
paddle
::
Tensor
>
ReluDoubleBackward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
if
(
out
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
return
relu_cpu_double_backward
(
out
,
ddx
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
std
::
vector
<
int64_t
>>
ReluDoubleBackwardInferShape
(
const
std
::
vector
<
int64_t
>&
out_shape
,
const
std
::
vector
<
int64_t
>&
ddx_shape
)
{
return
{
out_shape
};
}
PD_BUILD_OP
(
custom_relu
)
.
Inputs
({
"X"
})
.
Outputs
({
"Out"
})
.
SetKernelFn
(
PD_KERNEL
(
ReluForward
));
PD_BUILD_GRAD_OP
(
custom_relu
)
.
Inputs
({
"X"
,
"Out"
,
paddle
::
Grad
(
"Out"
)})
.
Outputs
({
paddle
::
Grad
(
"X"
)})
.
SetKernelFn
(
PD_KERNEL
(
ReluBackward
));
PD_BUILD_DOUBLE_GRAD_OP
(
custom_relu
)
.
Inputs
({
"Out"
,
paddle
::
Grad
(
paddle
::
Grad
(
"X"
))})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Grad
(
"Out"
))})
.
SetKernelFn
(
PD_KERNEL
(
ReluDoubleBackward
))
.
SetInferShapeFn
(
PD_INFER_SHAPE
(
ReluDoubleBackwardInferShape
));
// Extension with tensor operator overloading
paddle
::
Tensor
custom_sub2
(
paddle
::
Tensor
x
,
paddle
::
Tensor
y
)
{
return
paddle
::
exp
(
x
)
-
paddle
::
exp
(
y
);
}
// Extension with tensor operator overloading
paddle
::
Tensor
custom_add2
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
y
)
{
return
paddle
::
exp
(
x
)
+
paddle
::
exp
(
y
);
}
PYBIND11_MODULE
(
mix_relu_extension
,
m
)
{
m
.
def
(
"custom_add2"
,
&
custom_add2
,
"exp(x) + exp(y)"
);
m
.
def
(
"custom_sub2"
,
&
custom_sub2
,
"exp(x) - exp(y)"
);
}
python/paddle/fluid/tests/cpp_extension/mix_relu_and_extension_setup.py
0 → 100644
浏览文件 @
2135020a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
utils
import
paddle_includes
from
paddle.utils.cpp_extension
import
CppExtension
,
setup
setup
(
name
=
'mix_relu_extension'
,
ext_modules
=
CppExtension
(
sources
=
[
"mix_relu_and_extension.cc"
,
"custom_sub.cc"
],
include_dirs
=
paddle_includes
+
[
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))],
extra_compile_args
=
{
'cc'
:
[
'-w'
,
'-g'
]},
verbose
=
True
,
),
)
python/paddle/fluid/tests/cpp_extension/test_cpp_extension_setup.py
浏览文件 @
2135020a
...
@@ -20,9 +20,80 @@ import unittest
...
@@ -20,9 +20,80 @@ import unittest
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.static
as
static
from
paddle.utils.cpp_extension.extension_utils
import
run_cmd
from
paddle.utils.cpp_extension.extension_utils
import
run_cmd
def
custom_relu_static
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
,
test_infer
=
False
):
paddle
.
enable_static
()
paddle
.
set_device
(
device
)
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
program_guard
(
static
.
Program
()):
x
=
static
.
data
(
name
=
'X'
,
shape
=
[
None
,
8
],
dtype
=
dtype
)
x
.
stop_gradient
=
False
out
=
func
(
x
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
x
)
static
.
append_backward
(
out
)
exe
=
static
.
Executor
()
exe
.
run
(
static
.
default_startup_program
())
# in static graph mode, x data has been covered by out
out_v
=
exe
.
run
(
static
.
default_main_program
(),
feed
=
{
'X'
:
np_x
},
fetch_list
=
[
out
.
name
],
)
paddle
.
disable_static
()
return
out_v
def
custom_relu_dynamic
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
set_device
(
device
)
t
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
)
t
.
stop_gradient
=
False
out
=
func
(
t
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
t
)
out
.
stop_gradient
=
False
out
.
backward
()
if
t
.
grad
is
None
:
return
out
.
numpy
(),
t
.
grad
else
:
return
out
.
numpy
(),
t
.
grad
.
numpy
()
def
custom_relu_double_grad_dynamic
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
set_device
(
device
)
t
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
,
stop_gradient
=
False
)
t
.
retain_grads
()
out
=
func
(
t
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
t
)
out
.
retain_grads
()
dx
=
paddle
.
grad
(
outputs
=
out
,
inputs
=
t
,
grad_outputs
=
paddle
.
ones_like
(
t
),
create_graph
=
True
,
retain_graph
=
True
,
)
ddout
=
paddle
.
grad
(
outputs
=
dx
[
0
],
inputs
=
out
.
grad
,
grad_outputs
=
paddle
.
ones_like
(
t
),
create_graph
=
False
,
)
assert
ddout
[
0
].
numpy
()
is
not
None
return
dx
[
0
].
numpy
(),
ddout
[
0
].
numpy
()
class
TestCppExtensionSetupInstall
(
unittest
.
TestCase
):
class
TestCppExtensionSetupInstall
(
unittest
.
TestCase
):
"""
"""
Tests setup install cpp extensions.
Tests setup install cpp extensions.
...
@@ -30,23 +101,14 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -30,23 +101,14 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
# install general extension
# compile, install the custom op egg into site-packages under background
# compile, install the custom op egg into site-packages under background
if
os
.
name
==
'nt'
:
cmd
=
'cd {} && {} cpp_extension_setup.py install'
.
format
(
cmd
=
'cd /d {} && python cpp_extension_setup.py install'
.
format
(
cur_dir
,
sys
.
executable
cur_dir
)
)
else
:
cmd
=
'cd {} && {} cpp_extension_setup.py install'
.
format
(
cur_dir
,
sys
.
executable
)
run_cmd
(
cmd
)
run_cmd
(
cmd
)
# os.system(cmd)
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
site_dir
=
site
.
getsitepackages
()[
0
]
if
os
.
name
==
'nt'
:
site_dir
=
site
.
getsitepackages
()[
1
]
else
:
site_dir
=
site
.
getsitepackages
()[
0
]
custom_egg_path
=
[
custom_egg_path
=
[
x
for
x
in
os
.
listdir
(
site_dir
)
if
'custom_cpp_extension'
in
x
x
for
x
in
os
.
listdir
(
site_dir
)
if
'custom_cpp_extension'
in
x
]
]
...
@@ -55,6 +117,22 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -55,6 +117,22 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
)
)
sys
.
path
.
append
(
os
.
path
.
join
(
site_dir
,
custom_egg_path
[
0
]))
sys
.
path
.
append
(
os
.
path
.
join
(
site_dir
,
custom_egg_path
[
0
]))
# install mixed custom_op and extension
cmd
=
'cd {} && {} mix_relu_and_extension_setup.py install'
.
format
(
cur_dir
,
sys
.
executable
)
run_cmd
(
cmd
)
site_dir
=
site
.
getsitepackages
()[
0
]
custom_egg_path
=
[
x
for
x
in
os
.
listdir
(
site_dir
)
if
'mix_relu_extension'
in
x
]
assert
len
(
custom_egg_path
)
==
1
,
"Matched egg number is %d."
%
len
(
custom_egg_path
)
sys
.
path
.
append
(
os
.
path
.
join
(
site_dir
,
custom_egg_path
[
0
]))
#################################
# config seed
# config seed
SEED
=
2021
SEED
=
2021
paddle
.
seed
(
SEED
)
paddle
.
seed
(
SEED
)
...
@@ -66,10 +144,16 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -66,10 +144,16 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
pass
pass
def
test_cpp_extension
(
self
):
def
test_cpp_extension
(
self
):
self
.
_test_extension_function
()
# Extension
self
.
_test_extension_function_plain
()
self
.
_test_extension_function_mixed
()
self
.
_test_extension_class
()
self
.
_test_extension_class
()
# Custom op
self
.
_test_static
()
self
.
_test_dynamic
()
self
.
_test_double_grad_dynamic
()
def
_test_extension_function
(
self
):
def
_test_extension_function
_plain
(
self
):
import
custom_cpp_extension
import
custom_cpp_extension
for
dtype
in
self
.
dtypes
:
for
dtype
in
self
.
dtypes
:
...
@@ -77,7 +161,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -77,7 +161,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
x
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
)
x
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
)
np_y
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
np_y
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
y
=
paddle
.
to_tensor
(
np_y
,
dtype
=
dtype
)
y
=
paddle
.
to_tensor
(
np_y
,
dtype
=
dtype
)
# Test custom_cpp_extension
out
=
custom_cpp_extension
.
custom_add
(
x
,
y
)
out
=
custom_cpp_extension
.
custom_add
(
x
,
y
)
target_out
=
np
.
exp
(
np_x
)
+
np
.
exp
(
np_y
)
target_out
=
np
.
exp
(
np_x
)
+
np
.
exp
(
np_y
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
...
@@ -87,10 +171,30 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -87,10 +171,30 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
target_out
=
np
.
exp
(
np_x
)
-
np
.
exp
(
np_y
)
target_out
=
np
.
exp
(
np_x
)
-
np
.
exp
(
np_y
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
def
_test_extension_function_mixed
(
self
):
import
mix_relu_extension
for
dtype
in
self
.
dtypes
:
np_x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
x
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
)
np_y
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
y
=
paddle
.
to_tensor
(
np_y
,
dtype
=
dtype
)
# Test mix_relu_extension
out
=
mix_relu_extension
.
custom_add2
(
x
,
y
)
target_out
=
np
.
exp
(
np_x
)
+
np
.
exp
(
np_y
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
# Test we can call a method not defined in the main C++ file.
out
=
mix_relu_extension
.
custom_sub2
(
x
,
y
)
target_out
=
np
.
exp
(
np_x
)
-
np
.
exp
(
np_y
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
target_out
,
atol
=
1e-5
)
def
_test_extension_class
(
self
):
def
_test_extension_class
(
self
):
import
custom_cpp_extension
import
custom_cpp_extension
for
dtype
in
self
.
dtypes
:
for
dtype
in
self
.
dtypes
:
# Test custom_cpp_extension
# Test we can use CppExtension class with C++ methods.
# Test we can use CppExtension class with C++ methods.
power
=
custom_cpp_extension
.
Power
(
3
,
3
)
power
=
custom_cpp_extension
.
Power
(
3
,
3
)
self
.
assertEqual
(
power
.
get
().
sum
(),
9
)
self
.
assertEqual
(
power
.
get
().
sum
(),
9
)
...
@@ -109,6 +213,77 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
...
@@ -109,6 +213,77 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
atol
=
1e-5
,
atol
=
1e-5
,
)
)
def
_test_static
(
self
):
import
mix_relu_extension
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
=
custom_relu_static
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
)
pd_out
=
custom_relu_static
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
def
_test_dynamic
(
self
):
import
mix_relu_extension
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
,
x_grad
=
custom_relu_dynamic
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
)
pd_out
,
pd_x_grad
=
custom_relu_dynamic
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
np
.
testing
.
assert_array_equal
(
x_grad
,
pd_x_grad
,
err_msg
=
'custom op x grad: {},
\n
paddle api x grad: {}'
.
format
(
x_grad
,
pd_x_grad
),
)
def
_test_double_grad_dynamic
(
self
):
import
mix_relu_extension
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
,
dx_grad
=
custom_relu_double_grad_dynamic
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
)
pd_out
,
pd_dx_grad
=
custom_relu_double_grad_dynamic
(
mix_relu_extension
.
custom_relu
,
"CPU"
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
np
.
testing
.
assert_array_equal
(
dx_grad
,
pd_dx_grad
,
err_msg
=
'custom op dx grad: {},
\n
paddle api dx grad: {}'
.
format
(
dx_grad
,
pd_dx_grad
),
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
os
.
name
==
'nt'
or
sys
.
platform
.
startswith
(
'darwin'
):
if
os
.
name
==
'nt'
or
sys
.
platform
.
startswith
(
'darwin'
):
...
...
python/paddle/fluid/tests/cpp_extension/utils.py
0 → 100644
浏览文件 @
2135020a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
from
site
import
getsitepackages
from
paddle.utils.cpp_extension.extension_utils
import
IS_WINDOWS
IS_MAC
=
sys
.
platform
.
startswith
(
'darwin'
)
# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI.
# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find
# paddle include directory. Because the following path is generated after installing
# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI.
paddle_includes
=
[]
for
site_packages_path
in
getsitepackages
():
paddle_includes
.
append
(
os
.
path
.
join
(
site_packages_path
,
'paddle'
,
'include'
)
)
paddle_includes
.
append
(
os
.
path
.
join
(
site_packages_path
,
'paddle'
,
'include'
,
'third_party'
)
)
# Test for extra compile args
extra_cc_args
=
[
'-w'
,
'-g'
]
if
not
IS_WINDOWS
else
[
'/w'
]
extra_nvcc_args
=
[
'-O3'
]
extra_compile_args
=
{
'cc'
:
extra_cc_args
,
'nvcc'
:
extra_nvcc_args
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录