Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7d15f930
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7d15f930
编写于
6月 24, 2022
作者:
C
cifar10
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mlu label_smooth kernel (#43743)
上级
d77c4955
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
209 addition
and
0 deletion
+209
-0
paddle/fluid/operators/label_smooth_op_mlu.cc
paddle/fluid/operators/label_smooth_op_mlu.cc
+90
-0
python/paddle/fluid/tests/unittests/mlu/test_label_smooth_op_mlu.py
...dle/fluid/tests/unittests/mlu/test_label_smooth_op_mlu.py
+119
-0
未找到文件。
paddle/fluid/operators/label_smooth_op_mlu.cc
0 → 100644
浏览文件 @
7d15f930
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
class
LabelSmoothMLUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_t
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
dist_t
=
ctx
.
Input
<
Tensor
>
(
"PriorDist"
);
auto
*
out_t
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
epsilon_gt
=
1.0
f
-
epsilon
;
if
(
in_t
->
numel
()
==
0
)
return
;
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
label_dim
=
in_t
->
dims
()[
in_t
->
dims
().
size
()
-
1
];
MLUCnnlTensorDesc
x_desc
(
*
in_t
);
MLUCnnlTensorDesc
out_desc
(
*
out_t
);
auto
data_type
=
ToCnnlDataType
<
T
>
();
MLUCnnlOpTensorDesc
op_tensor_desc
(
CNNL_OP_TENSOR_ADD
,
data_type
,
CNNL_NOT_PROPAGATE_NAN
);
if
(
ctx
.
HasInput
(
"PriorDist"
))
{
MLUCnnlTensorDesc
dist_desc
(
*
dist_t
);
MLUCnnl
::
OpTensor
(
ctx
,
op_tensor_desc
.
get
(),
x_desc
.
get
(),
GetBasePtr
(
in_t
),
dist_desc
.
get
(),
GetBasePtr
(
dist_t
),
out_desc
.
get
(),
GetBasePtr
(
out_t
),
data_type
,
epsilon_gt
,
epsilon
);
}
else
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MLUDeviceContext
>();
framework
::
Tensor
dist_tensor
=
ctx
.
AllocateTmpTensor
<
T
,
MLUDeviceContext
>
({
1
,
label_dim
},
dev_ctx
);
MLUCnnlTensorDesc
dist_desc
(
dist_tensor
);
auto
value
=
static_cast
<
T
>
(
1.0
f
/
label_dim
);
MLUCnnl
::
Fill
(
ctx
,
CNNL_POINTER_MODE_HOST
,
&
value
,
dist_desc
.
get
(),
GetBasePtr
(
&
dist_tensor
));
MLUCnnl
::
OpTensor
(
ctx
,
op_tensor_desc
.
get
(),
x_desc
.
get
(),
GetBasePtr
(
in_t
),
dist_desc
.
get
(),
GetBasePtr
(
&
dist_tensor
),
out_desc
.
get
(),
GetBasePtr
(
out_t
),
data_type
,
epsilon_gt
,
epsilon
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_MLU_KERNEL
(
label_smooth
,
ops
::
LabelSmoothMLUKernel
<
float
>
,
ops
::
LabelSmoothMLUKernel
<
plat
::
float16
>
);
python/paddle/fluid/tests/unittests/mlu/test_label_smooth_op_mlu.py
0 → 100644
浏览文件 @
7d15f930
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
unittest
import
sys
sys
.
path
.
append
(
".."
)
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
SEED
=
2022
paddle
.
enable_static
()
class
TestLabelSmoothOp
(
OpTest
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
config
(
self
):
self
.
op_type
=
"label_smooth"
self
.
place
=
paddle
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
self
.
__class__
.
no_need_check_grad
=
True
self
.
epsilon
=
0.1
batch_size
,
self
.
label_dim
=
10
,
12
np
.
random
.
seed
(
SEED
)
self
.
label
=
np
.
zeros
((
batch_size
,
self
.
label_dim
)).
astype
(
self
.
dtype
)
nonzero_index
=
np
.
random
.
randint
(
self
.
label_dim
,
size
=
(
batch_size
))
self
.
label
[
np
.
arange
(
batch_size
),
nonzero_index
]
=
1
def
setUp
(
self
):
self
.
init_dtype
()
self
.
config
()
smoothed_label
=
(
1
-
self
.
epsilon
)
*
self
.
label
+
self
.
epsilon
/
self
.
label_dim
smoothed_label
=
smoothed_label
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
self
.
label
}
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
}
self
.
outputs
=
{
'Out'
:
smoothed_label
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
class
TestLabelSmoothOpWithPriorDist
(
TestLabelSmoothOp
):
def
setUp
(
self
):
self
.
init_dtype
()
self
.
config
()
dist
=
np
.
random
.
random
((
1
,
self
.
label_dim
)).
astype
(
self
.
dtype
)
smoothed_label
=
(
1
-
self
.
epsilon
)
*
self
.
label
+
self
.
epsilon
*
dist
smoothed_label
=
smoothed_label
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
self
.
label
,
'PriorDist'
:
dist
}
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
}
self
.
outputs
=
{
'Out'
:
smoothed_label
}
class
TestLabelSmoothOp3D
(
TestLabelSmoothOp
):
def
setUp
(
self
):
super
(
TestLabelSmoothOp3D
,
self
).
setUp
()
self
.
inputs
[
'X'
]
=
self
.
inputs
[
'X'
].
reshape
(
[
2
,
-
1
,
self
.
inputs
[
'X'
].
shape
[
-
1
]])
self
.
outputs
[
'Out'
]
=
self
.
outputs
[
'Out'
].
reshape
(
self
.
inputs
[
'X'
].
shape
)
class
TestLabelSmoothOpWithPriorDist3D
(
TestLabelSmoothOpWithPriorDist
):
def
setUp
(
self
):
super
(
TestLabelSmoothOpWithPriorDist3D
,
self
).
setUp
()
self
.
inputs
[
'X'
]
=
self
.
inputs
[
'X'
].
reshape
(
[
2
,
-
1
,
self
.
inputs
[
'X'
].
shape
[
-
1
]])
self
.
outputs
[
'Out'
]
=
self
.
outputs
[
'Out'
].
reshape
(
self
.
inputs
[
'X'
].
shape
)
class
TestLabelSmoothOpFP16
(
TestLabelSmoothOp
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestLabelSmoothOpWithPriorDistFP16
(
TestLabelSmoothOpWithPriorDist
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestLabelSmoothOp3DFP16
(
TestLabelSmoothOp3D
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestLabelSmoothOpWithPriorDist3DFP16
(
TestLabelSmoothOpWithPriorDist3D
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录