Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b7bcd0f6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b7bcd0f6
编写于
2月 23, 2022
作者:
A
Aurelius84
提交者:
GitHub
2月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Migrate lable_smooth_op into Phi (#39796)
* [Phi] Migrate lable_smooth_op into Phi * fix PT->PD
上级
24f55aed
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
334 addition
and
208 deletion
+334
-208
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/operators/label_smooth_op.cc
paddle/fluid/operators/label_smooth_op.cc
+1
-10
paddle/fluid/operators/label_smooth_op.cu
paddle/fluid/operators/label_smooth_op.cu
+0
-125
paddle/fluid/operators/label_smooth_op.h
paddle/fluid/operators/label_smooth_op.h
+0
-70
paddle/fluid/operators/label_smooth_op_npu.cc
paddle/fluid/operators/label_smooth_op_npu.cc
+1
-1
paddle/fluid/operators/label_smooth_op_xpu.cc
paddle/fluid/operators/label_smooth_op_xpu.cc
+0
-1
paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc
paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc
+45
-0
paddle/phi/kernels/cpu/label_smooth_kernel.cc
paddle/phi/kernels/cpu/label_smooth_kernel.cc
+50
-0
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
+55
-0
paddle/phi/kernels/gpu/label_smooth_kernel.cu
paddle/phi/kernels/gpu/label_smooth_kernel.cu
+86
-0
paddle/phi/kernels/label_smooth_grad_kernel.h
paddle/phi/kernels/label_smooth_grad_kernel.h
+28
-0
paddle/phi/kernels/label_smooth_kernel.h
paddle/phi/kernels/label_smooth_kernel.h
+30
-0
paddle/phi/ops/compat/label_smooth_sig.cc
paddle/phi/ops/compat/label_smooth_sig.cc
+37
-0
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
b7bcd0f6
...
@@ -2040,7 +2040,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
...
@@ -2040,7 +2040,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
(
i
==
0
?
0
:
pt_kernel_context
->
InputRangeAt
(
i
-
1
).
second
);
(
i
==
0
?
0
:
pt_kernel_context
->
InputRangeAt
(
i
-
1
).
second
);
// deal with optional here
// deal with optional here
if
((
it
==
ctx
.
inputs
.
end
())
&&
if
((
it
==
ctx
.
inputs
.
end
()
||
it
->
second
.
size
()
==
0
)
&&
(
input_defs
[
i
].
type_index
==
(
input_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
paddle
::
optional
<
const
phi
::
DenseTensor
&>
))))
{
std
::
type_index
(
typeid
(
paddle
::
optional
<
const
phi
::
DenseTensor
&>
))))
{
pt_kernel_context
->
EmplaceBackInputWithoutSetRange
(
nullptr
);
pt_kernel_context
->
EmplaceBackInputWithoutSetRange
(
nullptr
);
...
...
paddle/fluid/operators/label_smooth_op.cc
浏览文件 @
b7bcd0f6
...
@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/label_smooth_op.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -152,11 +151,3 @@ REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
...
@@ -152,11 +151,3 @@ REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
ops
::
LabelSmoothGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LabelSmoothGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LabelSmoothGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
LabelSmoothGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
label_smooth_grad
,
ops
::
LabelSmoothGradOp
);
REGISTER_OPERATOR
(
label_smooth_grad
,
ops
::
LabelSmoothGradOp
);
REGISTER_OP_CPU_KERNEL
(
label_smooth
,
ops
::
LabelSmoothKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LabelSmoothKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
label_smooth_grad
,
ops
::
LabelSmoothGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LabelSmoothGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/label_smooth_op.cu
已删除
100644 → 0
浏览文件 @
24f55aed
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/label_smooth_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
LabelSmoothFunctor
{
T
epsilon
;
T
label_dim
;
__forceinline__
LabelSmoothFunctor
(
float
epsilon_data
,
int
label_dim_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
label_dim
=
static_cast
<
T
>
(
label_dim_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
(
static_cast
<
T
>
(
1
-
epsilon
)
*
x
+
static_cast
<
T
>
(
epsilon
/
label_dim
));
}
};
template
<
typename
T
>
struct
LabelSmoothGradFunctor
{
T
epsilon
;
__forceinline__
LabelSmoothGradFunctor
(
float
epsilon_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
T
>
(
1
-
epsilon
)
*
x
;
}
};
template
<
typename
T
>
__global__
void
LabelSmoothRunDistKernel
(
const
int
N
,
const
float
epsilon
,
const
int
dist_numel
,
const
T
*
src
,
const
T
*
dist_data
,
T
*
dst
)
{
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
int
dist_idx
=
idx
%
dist_numel
;
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
)
*
dist_data
[
dist_idx
];
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
LabelSmoothGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
dist_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"PriorDist"
);
auto
label_dim
=
in_t
->
dims
()[
in_t
->
dims
().
size
()
-
1
];
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
&
dev
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
size_prob
=
in_t
->
numel
();
const
T
*
in_data
=
in_t
->
data
<
T
>
();
T
*
out_data
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
threads
=
512
;
int
grid
=
(
size_prob
+
threads
-
1
)
/
threads
;
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
if
(
dist_t
)
{
auto
dist_numel
=
dist_t
->
numel
();
const
T
*
dist_data
=
dist_t
->
data
<
T
>
();
LabelSmoothRunDistKernel
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
size_prob
,
epsilon
,
dist_numel
,
in_data
,
dist_data
,
out_data
);
}
else
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
in_t
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out_t
};
auto
functor
=
LabelSmoothFunctor
<
T
>
(
epsilon
,
label_dim
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
LabelSmoothGradGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_in_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_in_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
d_out_t
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
d_in_t
};
auto
functor
=
LabelSmoothGradFunctor
<
T
>
(
epsilon
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
label_smooth
,
ops
::
LabelSmoothGPUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LabelSmoothGPUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
label_smooth_grad
,
ops
::
LabelSmoothGradGPUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
LabelSmoothGradGPUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/label_smooth_op.h
已删除
100644 → 0
浏览文件 @
24f55aed
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
LabelSmoothKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
dist_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"PriorDist"
);
auto
label_dim
=
in_t
->
dims
()[
in_t
->
dims
().
size
()
-
1
];
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
label_dim
!=
0
)
{
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out_t
);
auto
in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in_t
);
auto
&
dev
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
if
(
dist_t
)
{
auto
dist
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dist_t
);
out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
in
+
static_cast
<
T
>
(
epsilon
)
*
dist
.
broadcast
(
Eigen
::
DSizes
<
int
,
1
>
(
in_t
->
numel
()
/
label_dim
));
}
else
{
out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
in
+
static_cast
<
T
>
(
epsilon
/
label_dim
);
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
LabelSmoothGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_in_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_in_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
d_out_dim
=
d_out_t
->
dims
()[
d_out_t
->
dims
().
size
()
-
1
];
if
(
d_out_dim
!=
0
)
{
auto
d_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_out_t
);
auto
d_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_in_t
);
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
&
dev
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
d_in
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
d_out
;
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/label_smooth_op_npu.cc
浏览文件 @
b7bcd0f6
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/
operators/label_smooth_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/operators/label_smooth_op_xpu.cc
浏览文件 @
b7bcd0f6
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/label_smooth_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/phi/kernels/cpu/label_smooth_grad_kernel.cc
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LabelSmoothGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
out_grad
,
float
epsilon
,
DenseTensor
*
label_grad
)
{
ctx
.
template
Alloc
<
T
>(
label_grad
);
auto
d_out_dim
=
out_grad
.
dims
()[
out_grad
.
dims
().
size
()
-
1
];
if
(
d_out_dim
!=
0
)
{
auto
d_out
=
EigenVector
<
T
>::
Flatten
(
out_grad
);
auto
d_in
=
EigenVector
<
T
>::
Flatten
(
*
label_grad
);
auto
&
dev
=
*
ctx
.
eigen_device
();
d_in
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
d_out
;
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
label_smooth_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothGradKernel
,
float
,
double
)
{}
paddle/phi/kernels/cpu/label_smooth_kernel.cc
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/label_smooth_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LabelSmoothKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
label
,
paddle
::
optional
<
const
DenseTensor
&>
prior_dist
,
float
epsilon
,
DenseTensor
*
out
)
{
auto
label_dim
=
label
.
dims
()[
label
.
dims
().
size
()
-
1
];
ctx
.
template
Alloc
<
T
>(
out
);
auto
&
dev
=
*
ctx
.
eigen_device
();
if
(
label_dim
!=
0
)
{
auto
eigen_out
=
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
eigen_in
=
EigenVector
<
T
>::
Flatten
(
label
);
if
(
prior_dist
.
is_initialized
())
{
auto
dist
=
EigenVector
<
T
>::
Flatten
(
*
prior_dist
.
get_ptr
());
eigen_out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
eigen_in
+
static_cast
<
T
>
(
epsilon
)
*
dist
.
broadcast
(
Eigen
::
DSizes
<
int
,
1
>
(
label
.
numel
()
/
label_dim
));
}
else
{
eigen_out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
eigen_in
+
static_cast
<
T
>
(
epsilon
/
label_dim
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
label_smooth
,
CPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/label_smooth_grad_kernel.cu
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"
namespace
phi
{
template
<
typename
T
>
struct
LabelSmoothGradFunctor
{
T
epsilon
;
__forceinline__
LabelSmoothGradFunctor
(
float
epsilon_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
T
>
(
1
-
epsilon
)
*
x
;
}
};
template
<
typename
T
,
typename
Context
>
void
LabelSmoothGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
out_grad
,
float
epsilon
,
DenseTensor
*
label_grad
)
{
ctx
.
template
Alloc
<
T
>(
label_grad
);
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
out_grad
};
std
::
vector
<
DenseTensor
*>
outs
=
{
label_grad
};
auto
functor
=
LabelSmoothGradFunctor
<
T
>
(
epsilon
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
T
>
(
ctx
,
ins
,
&
outs
,
functor
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
label_smooth_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothGradKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/label_smooth_kernel.cu
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/label_smooth_kernel.h"
namespace
phi
{
template
<
typename
T
>
struct
LabelSmoothFunctor
{
T
epsilon
;
T
label_dim
;
__forceinline__
LabelSmoothFunctor
(
float
epsilon_data
,
int
label_dim_data
)
{
epsilon
=
static_cast
<
T
>
(
epsilon_data
);
label_dim
=
static_cast
<
T
>
(
label_dim_data
);
}
__device__
__forceinline__
T
operator
()(
const
T
x
)
const
{
return
(
static_cast
<
T
>
(
1
-
epsilon
)
*
x
+
static_cast
<
T
>
(
epsilon
/
label_dim
));
}
};
template
<
typename
T
>
__global__
void
LabelSmoothRunDistKernel
(
const
int
N
,
const
float
epsilon
,
const
int
dist_numel
,
const
T
*
src
,
const
T
*
dist_data
,
T
*
dst
)
{
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
int
dist_idx
=
idx
%
dist_numel
;
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
)
*
dist_data
[
dist_idx
];
}
}
template
<
typename
T
,
typename
Context
>
void
LabelSmoothKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
label
,
paddle
::
optional
<
const
DenseTensor
&>
prior_dist
,
float
epsilon
,
DenseTensor
*
out
)
{
auto
label_dim
=
label
.
dims
()[
label
.
dims
().
size
()
-
1
];
auto
size_prob
=
label
.
numel
();
const
T
*
in_data
=
label
.
data
<
T
>
();
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
if
(
prior_dist
.
get_ptr
())
{
int
threads
=
512
;
int
grid
=
(
size_prob
+
threads
-
1
)
/
threads
;
auto
stream
=
ctx
.
stream
();
const
auto
*
dist_t
=
prior_dist
.
get_ptr
();
auto
dist_numel
=
dist_t
->
numel
();
const
T
*
dist_data
=
dist_t
->
data
<
T
>
();
LabelSmoothRunDistKernel
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
size_prob
,
epsilon
,
dist_numel
,
in_data
,
dist_data
,
out_data
);
}
else
{
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
label
};
std
::
vector
<
DenseTensor
*>
outs
=
{
out
};
auto
functor
=
LabelSmoothFunctor
<
T
>
(
epsilon
,
label_dim
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
T
>
(
ctx
,
ins
,
&
outs
,
functor
);
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
label_smooth
,
GPU
,
ALL_LAYOUT
,
phi
::
LabelSmoothKernel
,
float
,
double
)
{}
paddle/phi/kernels/label_smooth_grad_kernel.h
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LabelSmoothGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
out_grad
,
float
epsilon
,
DenseTensor
*
label_grad
);
}
// namespace phi
paddle/phi/kernels/label_smooth_kernel.h
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/utils/optional.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
LabelSmoothKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
label
,
paddle
::
optional
<
const
DenseTensor
&>
prior_dist
,
float
epsilon
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/label_smooth_sig.cc
0 → 100644
浏览文件 @
b7bcd0f6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
LabelSmoothOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"label_smooth"
,
{
"X"
,
"PriorDist"
},
{
"epsilon"
},
{
"Out"
});
}
KernelSignature
LabelSmoothGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"label_smooth_grad"
,
{
GradVarName
(
"Out"
)},
{
"epsilon"
},
{
GradVarName
(
"X"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
label_smooth
,
phi
::
LabelSmoothOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
label_smooth_grad
,
phi
::
LabelSmoothGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录