Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9cd99f7e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9cd99f7e
编写于
3月 14, 2023
作者:
I
Infinity_lee
提交者:
GitHub
3月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【hackathon 4 No53】label_smooth add fp16 support (#51493)
上级
775fb43a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
43 addition
and
19 deletion
+43
-19
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
+8
-5
paddle/phi/kernels/gpu/label_smooth_kernel.cu
paddle/phi/kernels/gpu/label_smooth_kernel.cu
+23
-11
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
+10
-1
python/paddle/nn/functional/common.py
python/paddle/nn/functional/common.py
+2
-2
未找到文件。
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
浏览文件 @
9cd99f7e
...
...
@@ -15,20 +15,22 @@
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace
phi
{
template
<
typename
T
>
struct
LabelSmoothGradFunctor
{
T
epsilon
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
epsilon
;
__forceinline__
LabelSmoothGradFunctor
(
float
epsilon_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
epsilon
=
static_cast
<
MPType
>
(
epsilon_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
T
>
(
1
-
epsilon
)
*
x
;
return
static_cast
<
T
>
((
static_cast
<
MPType
>
(
1
)
-
epsilon
)
*
static_cast
<
MPType
>
(
x
));
}
};
...
...
@@ -52,4 +54,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
ALL_LAYOUT
,
phi
::
LabelSmoothGradKernel
,
float
,
double
)
{}
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/label_smooth_kernel.cu
浏览文件 @
9cd99f7e
...
...
@@ -17,24 +17,27 @@
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace
phi
{
template
<
typename
T
>
struct
LabelSmoothFunctor
{
T
epsilon
;
T
label_dim
;
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MPType
epsilon
;
MPType
label_dim
;
__forceinline__
LabelSmoothFunctor
(
float
epsilon_data
,
int
label_dim_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
label_dim
=
static_cast
<
T
>
(
label_dim_data
);
epsilon
=
static_cast
<
MPType
>
(
epsilon_data
);
label_dim
=
static_cast
<
MPType
>
(
label_dim_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
(
static_cast
<
T
>
(
1
-
epsilon
)
*
x
+
static_cast
<
T
>
(
epsilon
/
label_dim
));
return
static_cast
<
T
>
(
static_cast
<
MPType
>
(
static_cast
<
MPType
>
(
1
)
-
epsilon
)
*
static_cast
<
MPType
>
(
x
)
+
static_cast
<
MPType
>
(
epsilon
/
label_dim
));
}
};
...
...
@@ -45,10 +48,14 @@ __global__ void LabelSmoothRunDistKernel(const int N,
const
T
*
src
,
const
T
*
dist_data
,
T
*
dst
)
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
int
dist_idx
=
idx
%
dist_numel
;
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
)
*
dist_data
[
dist_idx
];
dst
[
idx
]
=
static_cast
<
T
>
((
static_cast
<
MPType
>
(
1
)
-
static_cast
<
MPType
>
(
epsilon
))
*
static_cast
<
MPType
>
(
src
[
idx
])
+
static_cast
<
MPType
>
(
epsilon
)
*
static_cast
<
MPType
>
(
dist_data
[
dist_idx
]));
}
}
...
...
@@ -83,5 +90,10 @@ void LabelSmoothKernel(const Context& ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
label_smooth
,
GPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothKernel
,
float
,
double
)
{}
PD_REGISTER_KERNEL
(
label_smooth
,
GPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
浏览文件 @
9cd99f7e
...
...
@@ -24,9 +24,10 @@ class TestLabelSmoothOp(OpTest):
def
config
(
self
):
self
.
op_type
=
"label_smooth"
self
.
python_api
=
paddle
.
nn
.
functional
.
label_smooth
self
.
init_dtype
()
self
.
epsilon
=
0.1
batch_size
,
self
.
label_dim
=
10
,
12
self
.
label
=
np
.
zeros
((
batch_size
,
self
.
label_dim
)).
astype
(
"float64"
)
self
.
label
=
np
.
zeros
((
batch_size
,
self
.
label_dim
)).
astype
(
self
.
dtype
)
nonzero_index
=
np
.
random
.
randint
(
self
.
label_dim
,
size
=
(
batch_size
))
self
.
label
[
np
.
arange
(
batch_size
),
nonzero_index
]
=
1
...
...
@@ -39,6 +40,9 @@ class TestLabelSmoothOp(OpTest):
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
}
self
.
outputs
=
{
'Out'
:
smoothed_label
}
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float64
def
test_check_output
(
self
):
self
.
check_output
(
check_eager
=
True
)
...
...
@@ -46,6 +50,11 @@ class TestLabelSmoothOp(OpTest):
self
.
check_grad
([
"X"
],
"Out"
,
check_eager
=
True
)
class
TestLabelSmoothFP16OP
(
TestLabelSmoothOp
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestLabelSmoothOpWithPriorDist
(
TestLabelSmoothOp
):
def
setUp
(
self
):
self
.
config
()
...
...
python/paddle/nn/functional/common.py
浏览文件 @
9cd99f7e
...
...
@@ -1923,7 +1923,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
label(Tensor): The input variable containing the label data. The
label data should use one-hot representation. It's
a multidimensional tensor with a shape of
:math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64".
:math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float
16" "float
32" and "float64".
prior_dist(Tensor, optional): The prior distribution to be used to smooth
labels. If not provided, an uniform distribution
is used. It's a multidimensional tensor with a shape of
...
...
@@ -1965,7 +1965,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
)
check_variable_and_dtype
(
label
,
'label'
,
[
'float32'
,
'float64'
],
'label_smooth'
label
,
'label'
,
[
'float
16'
,
'float
32'
,
'float64'
],
'label_smooth'
)
helper
=
LayerHelper
(
"label_smooth"
,
**
locals
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录