Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
649948a6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
649948a6
编写于
3月 29, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
3月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
softmax_with_cross_entropy support fp16 on xpu, test=kunlun (#40869)
上级
3b381aac
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
20 deletion
+36
-20
paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
+34
-19
paddle/fluid/platform/device/xpu/xpu2_op_list.h
paddle/fluid/platform/device/xpu/xpu2_op_list.h
+2
-1
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
浏览文件 @
649948a6
...
...
@@ -28,6 +28,8 @@ namespace operators {
template
<
typename
T
>
class
SoftmaxWithCrossEntropyXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
...
...
@@ -48,6 +50,10 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
logits_dims
=
phi
::
vectorize
<
int
>
(
logits
->
dims
());
const
bool
soft_label
=
context
.
Attr
<
bool
>
(
"soft_label"
);
auto
logits_data
=
reinterpret_cast
<
const
XPUType
*>
(
logits
->
data
<
T
>
());
auto
softmax_data
=
reinterpret_cast
<
XPUType
*>
(
softmax
->
data
<
T
>
());
auto
loss_data
=
reinterpret_cast
<
XPUType
*>
(
loss
->
data
<
T
>
());
// softmax
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
XPUDeviceContext
>();
...
...
@@ -55,32 +61,41 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
if
(
platform
::
get_xpu_version
(
context
.
GetPlace
().
GetDeviceId
())
==
phi
::
backends
::
xpu
::
XPUVersion
::
XPU2
&&
soft_label
)
{
r
=
xpu
::
soft_softmax_with_cross_entropy
(
dev_ctx
.
x_context
(),
logits
->
data
<
float
>
(),
labels
->
data
<
T
>
(),
softmax
->
data
<
T
>
(),
loss
->
data
<
T
>
(),
n
,
d
);
auto
labels_data
=
reinterpret_cast
<
const
XPUType
*>
(
labels
->
data
<
T
>
());
r
=
xpu
::
soft_softmax_with_cross_entropy
<
XPUType
>
(
dev_ctx
.
x_context
(),
logits_data
,
labels_data
,
softmax_data
,
loss_data
,
n
,
d
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"soft_softmax_with_cross_entropy"
);
return
;
}
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
int
len
=
logits
->
numel
();
T
*
clip_logits_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
T
>
(
len
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
clip_logits_data
);
T
*
clip_logits
=
RAII_GUARD
.
alloc_l3_or_gm
<
T
>
(
len
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
clip_logits
);
XPUType
*
clip_logits_data
=
reinterpret_cast
<
XPUType
*>
(
clip_logits
);
float
max_val
=
1e20
;
float
min_val
=
-
1e20
;
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
max_val
=
65504
;
min_val
=
-
65504
;
}
r
=
xpu
::
clip_v2
(
dev_ctx
.
x_context
(),
logits
->
data
<
float
>
(),
clip_logits_data
,
len
,
static_cast
<
float
>
(
-
1e20
)
,
static_cast
<
float
>
(
1e20
));
r
=
xpu
::
clip_v2
<
XPUType
>
(
dev_ctx
.
x_context
(),
logits_data
,
clip_logits_data
,
len
,
static_cast
<
XPUType
>
(
min_val
),
static_cast
<
XPUType
>
(
max_val
));
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"clip_v2"
);
r
=
xpu
::
softmax
(
dev_ctx
.
x_context
(),
clip_logits_data
,
softmax
->
data
<
float
>
()
,
logits_dims
,
axis
);
r
=
xpu
::
softmax
<
XPUType
>
(
dev_ctx
.
x_context
(),
clip_logits_data
,
softmax_data
,
logits_dims
,
axis
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"softmax"
);
// cross_entropy
if
(
soft_label
)
{
r
=
xpu
::
soft_cross_entropy
<
float
>
(
dev_ctx
.
x_context
(),
softmax
->
data
<
float
>
(),
labels
->
data
<
float
>
()
,
loss
->
data
<
float
>
()
,
n
,
d
);
auto
labels_data
=
reinterpret_cast
<
const
XPUType
*>
(
labels
->
data
<
T
>
());
r
=
xpu
::
soft_cross_entropy
<
XPUType
>
(
dev_ctx
.
x_context
(),
softmax_data
,
labels_data
,
loss_data
,
n
,
d
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"soft_cross_entropy"
);
}
else
{
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
...
...
@@ -92,10 +107,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
labels_int32
.
data
<
int32_t
>
(),
labels
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"clip_v2"
);
r
=
xpu
::
hard_cross_entropy
<
float
,
int32_t
>
(
dev_ctx
.
x_context
(),
softmax
->
data
<
float
>
(),
labels_int32
.
data
<
int32_t
>
(),
loss
->
data
<
float
>
(),
nullptr
,
n
,
d
,
ignore_index
);
r
=
xpu
::
hard_cross_entropy
<
XPUType
,
int32_t
>
(
dev_ctx
.
x_context
(),
softmax_data
,
labels_int32
.
data
<
int32_t
>
(),
loss_data
,
nullptr
,
n
,
d
,
ignore_index
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"hard_cross_entropy"
);
}
}
...
...
@@ -167,8 +181,9 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyXPUKernel
<
float
>
);
REGISTER_OP_XPU_KERNEL
(
softmax_with_cross_entropy
,
ops
::
SoftmaxWithCrossEntropyXPUKernel
<
float
>
,
ops
::
SoftmaxWithCrossEntropyXPUKernel
<
paddle
::
platform
::
float16
>
);
REGISTER_OP_XPU_KERNEL
(
softmax_with_cross_entropy_grad
,
ops
::
SoftmaxWithCrossEntropyGradXPUKernel
<
float
>
,
...
...
paddle/fluid/platform/device/xpu/xpu2_op_list.h
浏览文件 @
649948a6
...
...
@@ -321,7 +321,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"softmax_with_cross_entropy"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"softplus"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"softplus_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录