Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
db50fb67
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
db50fb67
编写于
3月 03, 2021
作者:
Q
Qi Li
提交者:
GitHub
3月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] fix softmax with loss and update python scripts, test=develop (#31373)
上级
32211fe9
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
282 addition
and
48 deletion
+282
-48
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+112
-3
paddle/fluid/platform/for_range.h
paddle/fluid/platform/for_range.h
+5
-0
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+9
-2
python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
+12
-2
python/paddle/fluid/tests/unittests/test_conv2d_op.py
python/paddle/fluid/tests/unittests/test_conv2d_op.py
+4
-2
python/paddle/fluid/tests/unittests/test_pool2d_op.py
python/paddle/fluid/tests/unittests/test_pool2d_op.py
+5
-1
python/paddle/fluid/tests/unittests/test_softmax_op.py
python/paddle/fluid/tests/unittests/test_softmax_op.py
+9
-3
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...uid/tests/unittests/test_softmax_with_cross_entropy_op.py
+33
-23
python/paddle/utils/cpp_extension/cpp_extension.py
python/paddle/utils/cpp_extension/cpp_extension.py
+21
-8
python/paddle/utils/cpp_extension/extension_utils.py
python/paddle/utils/cpp_extension/extension_utils.py
+72
-4
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
db50fb67
...
...
@@ -8,7 +8,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
...
...
@@ -214,6 +220,60 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
}
#ifdef __HIPCC__ // @{ HIP Seperate Kernel for RowReductionForDiffMaxSum
// Note(qili93): HIP do not support return in kernel, need to seperate
// RowReductionForDiffMaxSum into two kernels below
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
RowReductionForSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int64_t
d
,
int
axis_dim
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
int64_t
remain
=
d
/
axis_dim
;
int64_t
idx_n
=
blockIdx
.
x
/
remain
;
int64_t
idx_remain
=
blockIdx
.
x
%
remain
;
int64_t
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int64_t
end_idx
=
(
idx_n
+
1
)
*
d
;
auto
block_max
=
max_data
[
blockIdx
.
x
];
int64_t
step
=
BlockDim
*
remain
;
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
exp_on_device
(
softmax
[
beg_idx
]);
auto
idx
=
beg_idx
+
step
;
while
(
idx
<
end_idx
)
{
softmax
[
idx
]
=
logits_data
[
idx
]
-
block_max
;
diff_max_sum
+=
exp_on_device
(
softmax
[
idx
]);
idx
+=
step
;
}
diff_max_sum
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
diff_max_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
log_on_device
(
diff_max_sum
);
}
template
<
typename
T
,
int
BlockDim
,
bool
CalculateLogSoftmax
=
false
>
static
__global__
void
RowReductionForDiff
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
d
,
int
axis_dim
)
{
int
remain
=
d
/
axis_dim
;
int
idx_n
=
blockIdx
.
x
/
remain
;
int
idx_remain
=
blockIdx
.
x
%
remain
;
int
beg_idx
=
idx_n
*
d
+
threadIdx
.
x
*
remain
+
idx_remain
;
int
end_idx
=
(
idx_n
+
1
)
*
d
;
int
step
=
BlockDim
*
remain
;
T
diff_max_sum
=
max_data
[
blockIdx
.
x
];
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
step
;
while
(
beg_idx
<
end_idx
)
{
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
step
;
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
}
#endif // @} End HIP Seperate Kernel for RowReductionForDiffMaxSum
// Make sure that BlockDim <= axis_dim
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
...
...
@@ -345,6 +405,28 @@ static void HardLabelSoftmaxWithCrossEntropy(
int64_t
grid_dim
=
n
*
d
/
axis_dim
;
auto
stream
=
ctx
.
stream
();
#ifdef __HIPCC__
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d); \
if (ignore_idx >= 0 && ignore_idx < axis_dim) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} \
} break
#else
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
...
...
@@ -361,6 +443,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} \
} break
#endif
switch
(
block_dim
)
{
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
...
...
@@ -383,13 +466,27 @@ static void HardLabelSoftmaxWithCrossEntropy(
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
softmax_data
,
T
*
loss_data
,
int64_t
n
,
int64_t
d
,
int
axis_dim
,
cuda
Stream_t
stream
)
{
int64_t
n
,
int64_t
d
,
int
axis_dim
,
gpu
Stream_t
stream
)
{
constexpr
int
kMaxBlockDim
=
512
;
int64_t
block_dim
=
axis_dim
>=
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
axis_dim
)));
int64_t
grid_dim
=
n
*
d
/
axis_dim
;
#ifdef __HIPCC__
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
hipLaunchKernelGGL( \
HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, labels_data, \
loss_data, softmax_data, d, axis_dim); \
break
#else
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
...
...
@@ -400,6 +497,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(
T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, softmax_data, d, axis_dim); \
break
#endif
switch
(
block_dim
)
{
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
...
...
@@ -536,6 +634,16 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_CUDA_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyCUDAKernel
<
float
>
,
ops
::
SoftmaxWithCrossEntropyCUDAKernel
<
paddle
::
platform
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyGradCUDAKernel
<
float
>
,
ops
::
SoftmaxWithCrossEntropyGradCUDAKernel
<
paddle
::
platform
::
float16
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyCUDAKernel
<
float
>
,
ops
::
SoftmaxWithCrossEntropyCUDAKernel
<
paddle
::
platform
::
float16
>
,
...
...
@@ -545,3 +653,4 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
SoftmaxWithCrossEntropyGradCUDAKernel
<
float
>
,
ops
::
SoftmaxWithCrossEntropyGradCUDAKernel
<
paddle
::
platform
::
float16
>
,
ops
::
SoftmaxWithCrossEntropyGradCUDAKernel
<
double
>
);
#endif
paddle/fluid/platform/for_range.h
浏览文件 @
db50fb67
...
...
@@ -62,7 +62,12 @@ struct ForRange<CUDADeviceContext> {
template
<
typename
Function
>
inline
void
operator
()(
Function
func
)
const
{
#ifdef __HIPCC__
// HIP will throw core dump when threads > 256
constexpr
int
num_threads
=
256
;
#else
constexpr
int
num_threads
=
1024
;
#endif
size_t
block_size
=
limit_
<=
num_threads
?
limit_
:
num_threads
;
size_t
grid_size
=
(
limit_
+
num_threads
-
1
)
/
num_threads
;
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
db50fb67
...
...
@@ -91,7 +91,11 @@ class TestParameter(object):
x
=
fluid
.
dygraph
.
to_variable
(
np_x
)
z
=
eval
(
"paddle.%s(x).numpy()"
%
self
.
op_type
)
z_expected
=
eval
(
"np.%s(np_x)"
%
self
.
op_type
)
self
.
assertEqual
(
z
,
z_expected
)
# ROCM platform will fail in assertEqual
if
core
.
is_compiled_with_rocm
():
self
.
assertTrue
(
np
.
allclose
(
z
,
z_expected
))
else
:
self
.
assertEqual
(
z
,
z_expected
)
class
TestSigmoid
(
TestActivation
):
...
...
@@ -2651,7 +2655,10 @@ create_test_act_fp16_class(TestSoftRelu)
create_test_act_fp16_class
(
TestELU
)
create_test_act_fp16_class
(
TestReciprocal
)
create_test_act_fp16_class
(
TestLog
)
create_test_act_fp16_class
(
TestLog2
,
atol
=
5e-2
)
if
core
.
is_compiled_with_rocm
():
create_test_act_fp16_class
(
TestLog2
,
atol
=
5e-2
,
grad_atol
=
0.85
)
else
:
create_test_act_fp16_class
(
TestLog2
,
atol
=
5e-2
)
create_test_act_fp16_class
(
TestLog10
,
atol
=
5e-2
)
create_test_act_fp16_class
(
TestLog1p
,
grad_atol
=
0.9
)
create_test_act_fp16_class
(
TestSquare
)
...
...
python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
浏览文件 @
db50fb67
...
...
@@ -171,7 +171,11 @@ class TestBatchNorm(unittest.TestCase):
class
TestBatchNormChannelLast
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
original_dtyep
=
paddle
.
get_default_dtype
()
paddle
.
set_default_dtype
(
"float64"
)
# MIOPEN not support data type of double
if
core
.
is_compiled_with_rocm
():
paddle
.
set_default_dtype
(
"float32"
)
else
:
paddle
.
set_default_dtype
(
"float64"
)
self
.
places
=
[
fluid
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"batch_norm"
):
self
.
places
.
append
(
fluid
.
CUDAPlace
(
0
))
...
...
@@ -219,7 +223,13 @@ class TestBatchNormChannelLast(unittest.TestCase):
channel_first_x
=
paddle
.
transpose
(
x
,
[
0
,
4
,
1
,
2
,
3
])
y2
=
net2
(
channel_first_x
)
y2
=
paddle
.
transpose
(
y2
,
[
0
,
2
,
3
,
4
,
1
])
self
.
assertEqual
(
np
.
allclose
(
y1
.
numpy
(),
y2
.
numpy
()),
True
)
if
core
.
is_compiled_with_rocm
():
# HIP will fail if no atol
self
.
assertEqual
(
np
.
allclose
(
y1
.
numpy
(),
y2
.
numpy
(),
atol
=
1e-07
),
True
)
else
:
self
.
assertEqual
(
np
.
allclose
(
y1
.
numpy
(),
y2
.
numpy
()),
True
)
class
TestBatchNormUseGlobalStats
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/test_conv2d_op.py
浏览文件 @
db50fb67
...
...
@@ -298,7 +298,8 @@ class TestConv2DOp(OpTest):
self
.
use_mkldnn
=
False
self
.
fuse_relu_before_depthwise_conv
=
False
self
.
data_format
=
"AnyLayout"
self
.
dtype
=
np
.
float64
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self
.
dtype
=
np
.
float32
if
core
.
is_compiled_with_rocm
()
else
np
.
float64
self
.
init_kernel_type
()
self
.
init_group
()
self
.
init_dilation
()
...
...
@@ -732,7 +733,8 @@ class TestConv2DOp_v2(OpTest):
self
.
use_cuda
=
False
self
.
use_mkldnn
=
False
self
.
fuse_relu_before_depthwise_conv
=
False
self
.
dtype
=
np
.
float64
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self
.
dtype
=
np
.
float32
if
core
.
is_compiled_with_rocm
()
else
np
.
float64
self
.
init_kernel_type
()
self
.
init_group
()
self
.
init_dilation
()
...
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
浏览文件 @
db50fb67
...
...
@@ -41,6 +41,8 @@ def max_pool2D_forward_naive(x,
exclusive
=
True
,
adaptive
=
False
,
data_type
=
np
.
float64
):
if
data_type
==
np
.
float64
and
core
.
is_compiled_with_rocm
():
data_type
=
np
.
float32
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
...
...
@@ -81,6 +83,8 @@ def avg_pool2D_forward_naive(x,
exclusive
=
True
,
adaptive
=
False
,
data_type
=
np
.
float64
):
if
data_type
==
np
.
float64
and
core
.
is_compiled_with_rocm
():
data_type
=
np
.
float32
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
...
...
@@ -340,7 +344,7 @@ class TestPool2D_Op(OpTest):
self
.
use_cudnn
=
False
def
init_data_type
(
self
):
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
def
init_pool_type
(
self
):
self
.
pool_type
=
"avg"
...
...
python/paddle/fluid/tests/unittests/test_softmax_op.py
浏览文件 @
db50fb67
...
...
@@ -55,7 +55,8 @@ class TestSoftmaxOp(OpTest):
self
.
op_type
=
"softmax"
self
.
use_cudnn
=
False
self
.
use_mkldnn
=
False
self
.
dtype
=
np
.
float64
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self
.
dtype
=
np
.
float32
if
core
.
is_compiled_with_rocm
()
else
np
.
float64
self
.
init_kernel_type
()
self
.
shape
=
self
.
get_x_shape
()
self
.
axis
=
self
.
get_axis
()
...
...
@@ -338,8 +339,13 @@ class TestSoftmaxAPI(unittest.TestCase):
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
out
=
self
.
softmax
(
x
,
dtype
=
np
.
float64
)
out_ref
=
ref_softmax
(
self
.
x_np
,
axis
=-
1
,
dtype
=
np
.
float64
)
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
if
core
.
is_compiled_with_rocm
():
out
=
self
.
softmax
(
x
,
dtype
=
np
.
float32
)
out_ref
=
ref_softmax
(
self
.
x_np
,
axis
=-
1
,
dtype
=
np
.
float32
)
else
:
out
=
self
.
softmax
(
x
,
dtype
=
np
.
float64
)
out_ref
=
ref_softmax
(
self
.
x_np
,
axis
=-
1
,
dtype
=
np
.
float64
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
out
.
numpy
()),
True
)
paddle
.
enable_static
()
...
...
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
浏览文件 @
db50fb67
...
...
@@ -51,7 +51,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
False
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
# explicilty use float32 for ROCm, as MIOpen does not yet support float64
self
.
dtype
=
np
.
float32
if
core
.
is_compiled_with_rocm
()
else
np
.
float64
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
shape
=
[
41
,
37
]
...
...
@@ -93,7 +94,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"Logits"
],
"Loss"
,
max_relative_error
=
5e-5
)
if
core
.
is_compiled_with_rocm
():
# HIP will have accuracy fail when using float32 in CPU place
self
.
check_grad
([
"Logits"
],
"Loss"
,
max_relative_error
=
5e-1
)
else
:
self
.
check_grad
([
"Logits"
],
"Loss"
,
max_relative_error
=
5e-5
)
class
TestSoftmaxWithCrossEntropyOpNoCudnn
(
TestSoftmaxWithCrossEntropyOp
):
...
...
@@ -104,7 +109,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
...
...
@@ -124,9 +129,10 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
date_type
=
np
.
float32
if
core
.
is_compiled_with_rocm
()
else
np
.
float64
logits
=
getattr
(
self
,
"logits"
,
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
np
.
float64
))
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
date_type
))
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
self
.
axis
,
logits
)
axis_dim
=
self
.
shape
[
self
.
axis
]
...
...
@@ -178,7 +184,7 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
True
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
shape
=
[
41
,
37
]
...
...
@@ -187,7 +193,11 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"Logits"
],
"Loss"
)
if
core
.
is_compiled_with_rocm
():
# HIP will have accuracy fail when using float32 in CPU place
self
.
check_grad
([
"Logits"
],
"Loss"
,
max_relative_error
=
0.1
)
else
:
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOp3
(
TestSoftmaxWithCrossEntropyOp
):
...
...
@@ -202,7 +212,7 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
self
.
shape
=
[
41
,
37
]
self
.
ignore_index
=
5
self
.
axis
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOp3NoCudnn
(
TestSoftmaxWithCrossEntropyOp3
):
...
...
@@ -213,7 +223,7 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
ignore_index
=
4
self
.
axis
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpAxis1
(
TestSoftmaxWithCrossEntropyOp
):
...
...
@@ -226,7 +236,7 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
0
self
.
ignore_index
=
-
1
self
.
shape
=
[
3
,
5
,
7
,
11
]
...
...
@@ -242,7 +252,7 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
1
self
.
ignore_index
=
-
1
self
.
shape
=
[
3
,
5
,
7
,
11
]
...
...
@@ -258,7 +268,7 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
2
self
.
ignore_index
=
-
1
self
.
shape
=
[
3
,
5
,
7
,
11
]
...
...
@@ -274,7 +284,7 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
3
self
.
ignore_index
=
-
1
self
.
shape
=
[
3
,
5
,
7
,
11
]
...
...
@@ -291,7 +301,7 @@ class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
shape
=
[
3
,
5
,
7
,
1
]
...
...
@@ -342,7 +352,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
0
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpSoftLabelAxis2
(
...
...
@@ -354,7 +364,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpSoftLabelAxis3
(
...
...
@@ -366,7 +376,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
2
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpSoftLabelAxis4
(
...
...
@@ -378,7 +388,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
3
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1
(
...
...
@@ -390,7 +400,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
ignore_index
=
1
self
.
axis
=
0
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2
(
...
...
@@ -402,7 +412,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
ignore_index
=
0
self
.
axis
=
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3
(
...
...
@@ -414,7 +424,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
ignore_index
=
3
self
.
axis
=
2
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4
(
...
...
@@ -426,7 +436,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
ignore_index
=
3
self
.
axis
=
3
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
class
TestSoftmaxWithCrossEntropyOpBoundary0
(
TestSoftmaxWithCrossEntropyOp
):
...
...
@@ -442,7 +452,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
logits
=
np
.
full
(
self
.
shape
,
-
500.0
).
astype
(
self
.
dtype
)
...
...
@@ -459,7 +469,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float64
self
.
dtype
=
np
.
float
32
if
core
.
is_compiled_with_rocm
()
else
np
.
float
64
self
.
logits
=
np
.
full
(
self
.
shape
,
1000.0
).
astype
(
self
.
dtype
)
self
.
logits
[:,
:,
0
,
:]
=
-
1000.0
...
...
python/paddle/utils/cpp_extension/cpp_extension.py
浏览文件 @
db50fb67
...
...
@@ -22,7 +22,7 @@ from setuptools.command.easy_install import easy_install
from
setuptools.command.build_ext
import
build_ext
from
distutils.command.build
import
build
from
.extension_utils
import
find_cuda_home
,
normalize_extension_kwargs
,
add_compile_flag
from
.extension_utils
import
find_cuda_home
,
find_rocm_home
,
normalize_extension_kwargs
,
add_compile_flag
from
.extension_utils
import
is_cuda_file
,
prepare_unix_cudaflags
,
prepare_win_cudaflags
from
.extension_utils
import
_import_module_from_library
,
_write_setup_file
,
_jit_compile
from
.extension_utils
import
check_abi_compatibility
,
log_v
,
CustomOpInfo
,
parse_op_name_from
...
...
@@ -31,6 +31,8 @@ from .extension_utils import bootstrap_context, get_build_directory, add_std_wit
from
.extension_utils
import
IS_WINDOWS
,
OS_NAME
,
MSVC_COMPILE_FLAGS
,
MSVC_COMPILE_FLAGS
from
...fluid
import
core
# Note(zhouwei): On windows, it will export function 'PyInit_[name]' by default,
# The solution is: 1.User add function PyInit_[name] 2. set not to export
# refer to https://stackoverflow.com/questions/34689210/error-exporting-symbol-when-building-python-c-extension-in-windows
...
...
@@ -39,7 +41,10 @@ if IS_WINDOWS and six.PY3:
from
unittest.mock
import
Mock
_du_build_ext
.
get_export_symbols
=
Mock
(
return_value
=
None
)
CUDA_HOME
=
find_cuda_home
()
if
core
.
is_compiled_with_rocm
():
ROCM_HOME
=
find_rocm_home
()
else
:
CUDA_HOME
=
find_cuda_home
()
def
setup
(
**
attr
):
...
...
@@ -394,12 +399,20 @@ class BuildExtension(build_ext, object):
original_compiler
=
self
.
compiler
.
compiler_so
# ncvv compile CUDA source
if
is_cuda_file
(
src
):
assert
CUDA_HOME
is
not
None
nvcc_cmd
=
os
.
path
.
join
(
CUDA_HOME
,
'bin'
,
'nvcc'
)
self
.
compiler
.
set_executable
(
'compiler_so'
,
nvcc_cmd
)
# {'nvcc': {}, 'cxx: {}}
if
isinstance
(
cflags
,
dict
):
cflags
=
cflags
[
'nvcc'
]
if
core
.
is_compiled_with_rocm
():
assert
ROCM_HOME
is
not
None
hipcc_cmd
=
os
.
path
.
join
(
ROCM_HOME
,
'bin'
,
'hipcc'
)
self
.
compiler
.
set_executable
(
'compiler_so'
,
hipcc_cmd
)
# {'nvcc': {}, 'cxx: {}}
if
isinstance
(
cflags
,
dict
):
cflags
=
cflags
[
'hipcc'
]
else
:
assert
CUDA_HOME
is
not
None
nvcc_cmd
=
os
.
path
.
join
(
CUDA_HOME
,
'bin'
,
'nvcc'
)
self
.
compiler
.
set_executable
(
'compiler_so'
,
nvcc_cmd
)
# {'nvcc': {}, 'cxx: {}}
if
isinstance
(
cflags
,
dict
):
cflags
=
cflags
[
'nvcc'
]
cflags
=
prepare_unix_cudaflags
(
cflags
)
# cxx compile Cpp source
...
...
python/paddle/utils/cpp_extension/extension_utils.py
浏览文件 @
db50fb67
...
...
@@ -464,6 +464,39 @@ def find_cuda_home():
return
cuda_home
def
find_rocm_home
():
"""
Use heuristic method to find rocm path
"""
# step 1. find in $ROCM_HOME or $ROCM_PATH
rocm_home
=
os
.
environ
.
get
(
'ROCM_HOME'
)
or
os
.
environ
.
get
(
'ROCM_PATH'
)
# step 2. find path by `which nvcc`
if
rocm_home
is
None
:
which_cmd
=
'where'
if
IS_WINDOWS
else
'which'
try
:
with
open
(
os
.
devnull
,
'w'
)
as
devnull
:
hipcc_path
=
subprocess
.
check_output
(
[
which_cmd
,
'hipcc'
],
stderr
=
devnull
)
if
six
.
PY3
:
hipcc_path
=
hipcc_path
.
decode
()
hipcc_path
=
hipcc_path
.
rstrip
(
'
\r\n
'
)
# for example: /opt/rocm/bin/hipcc
rocm_home
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
hipcc_path
))
except
:
rocm_home
=
"/opt/rocm"
# step 3. check whether path is valid
if
rocm_home
and
not
os
.
path
.
exists
(
rocm_home
)
and
core
.
is_compiled_with_rocm
():
rocm_home
=
None
warnings
.
warn
(
"Not found ROCM runtime, please use `export ROCM_PATH= XXX` to specific it."
)
return
rocm_home
def
find_cuda_includes
():
"""
Use heuristic method to find cuda include path
...
...
@@ -477,6 +510,19 @@ def find_cuda_includes():
return
[
os
.
path
.
join
(
cuda_home
,
'include'
)]
def
find_rocm_includes
():
"""
Use heuristic method to find rocm include path
"""
rocm_home
=
find_rocm_home
()
if
rocm_home
is
None
:
raise
ValueError
(
"Not found ROCM runtime, please use `export ROCM_PATH= XXX` to specific it."
)
return
[
os
.
path
.
join
(
rocm_home
,
'include'
)]
def
find_paddle_includes
(
use_cuda
=
False
):
"""
Return Paddle necessary include dir path.
...
...
@@ -487,8 +533,12 @@ def find_paddle_includes(use_cuda=False):
include_dirs
=
[
paddle_include_dir
,
third_party_dir
]
if
use_cuda
:
cuda_include_dir
=
find_cuda_includes
()
include_dirs
.
extend
(
cuda_include_dir
)
if
core
.
is_compiled_with_rocm
():
rocm_include_dir
=
find_rocm_includes
()
include_dirs
.
extend
(
rocm_include_dir
)
else
:
cuda_include_dir
=
find_cuda_includes
()
include_dirs
.
extend
(
cuda_include_dir
)
return
include_dirs
...
...
@@ -510,6 +560,20 @@ def find_cuda_libraries():
return
cuda_lib_dir
def
find_rocm_libraries
():
"""
Use heuristic method to find rocm dynamic lib path
"""
rocm_home
=
find_rocm_home
()
if
rocm_home
is
None
:
raise
ValueError
(
"Not found ROCM runtime, please use `export ROCM_PATH=XXX` to specific it."
)
rocm_lib_dir
=
[
os
.
path
.
join
(
rocm_home
,
'lib'
)]
return
rocm_lib_dir
def
find_paddle_libraries
(
use_cuda
=
False
):
"""
Return Paddle necessary library dir path.
...
...
@@ -518,8 +582,12 @@ def find_paddle_libraries(use_cuda=False):
paddle_lib_dirs
=
[
get_lib
()]
if
use_cuda
:
cuda_lib_dir
=
find_cuda_libraries
()
paddle_lib_dirs
.
extend
(
cuda_lib_dir
)
if
core
.
is_compiled_with_rocm
():
rocm_lib_dir
=
find_rocm_libraries
()
paddle_lib_dirs
.
extend
(
rocm_lib_dir
)
else
:
cuda_lib_dir
=
find_cuda_libraries
()
paddle_lib_dirs
.
extend
(
cuda_lib_dir
)
return
paddle_lib_dirs
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录