Unverified commit 9cd99f7e, authored by Infinity_lee and committed by GitHub

【hackathon 4 No53】label_smooth add fp16 support (#51493)

Parent 775fb43a
@@ -15,20 +15,22 @@
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {
template <typename T>
struct LabelSmoothGradFunctor {
-  T epsilon;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType epsilon;

  __forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
-    epsilon = static_cast<T>(epsilon_data);
+    epsilon = static_cast<MPType>(epsilon_data);
  }

  __device__ __forceinline__ T operator()(const T x) const {
-    return static_cast<T>(1 - epsilon) * x;
+    return static_cast<T>((static_cast<MPType>(1) - epsilon) *
+                          static_cast<MPType>(x));
  }
};

@@ -52,4 +54,5 @@ PD_REGISTER_KERNEL(label_smooth_grad,
                   ALL_LAYOUT,
                   phi::LabelSmoothGradKernel,
                   float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
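The heart of the change is the MPType alias: phi::dtype::MPTypeTrait<T>::Type resolves to float when T is float16 (and to T itself for float and double), so the functor promotes its operands, evaluates (1 - epsilon) * x in single precision, and only casts back to T for the stored result. Below is a minimal NumPy sketch of that pattern, assuming the trait behaves as described; the helper name is illustrative and not part of the patch.

    import numpy as np

    def smooth_grad_mp(grad, epsilon=0.1):
        # Promote half-precision values to float32, compute (1 - eps) * x there,
        # then cast back to the storage dtype, mirroring the functor above.
        mp = grad.astype(np.float32)
        return ((np.float32(1) - np.float32(epsilon)) * mp).astype(grad.dtype)

    x = np.random.rand(4, 12).astype(np.float16)
    ref = (1 - 0.1) * x.astype(np.float64)                    # high-precision reference
    naive = (np.float16(1) - np.float16(0.1)) * x             # all-fp16 arithmetic
    print(np.abs(naive.astype(np.float64) - ref).max())       # typically the larger error
    print(np.abs(smooth_grad_mp(x).astype(np.float64) - ref).max())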
@@ -17,24 +17,27 @@
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename T>
struct LabelSmoothFunctor {
-  T epsilon;
-  T label_dim;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType epsilon;
+  MPType label_dim;

  __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
-    epsilon = static_cast<T>(epsilon_data);
-    label_dim = static_cast<T>(label_dim_data);
+    epsilon = static_cast<MPType>(epsilon_data);
+    label_dim = static_cast<MPType>(label_dim_data);
  }

  __device__ __forceinline__ T operator()(const T x) const {
-    return (static_cast<T>(1 - epsilon) * x +
-            static_cast<T>(epsilon / label_dim));
+    return static_cast<T>(
+        static_cast<MPType>(static_cast<MPType>(1) - epsilon) *
+            static_cast<MPType>(x) +
+        static_cast<MPType>(epsilon / label_dim));
  }
};

@@ -45,10 +48,14 @@ __global__ void LabelSmoothRunDistKernel(const int N,
                                         const T* src,
                                         const T* dist_data,
                                         T* dst) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  CUDA_KERNEL_LOOP(idx, N) {
    int dist_idx = idx % dist_numel;
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
-               static_cast<T>(epsilon) * dist_data[dist_idx];
+    dst[idx] =
+        static_cast<T>((static_cast<MPType>(1) - static_cast<MPType>(epsilon)) *
+                           static_cast<MPType>(src[idx]) +
+                       static_cast<MPType>(epsilon) *
+                           static_cast<MPType>(dist_data[dist_idx]));
  }
}

@@ -83,5 +90,10 @@ void LabelSmoothKernel(const Context& ctx,

}  // namespace phi

-PD_REGISTER_KERNEL(
-    label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
+PD_REGISTER_KERNEL(label_smooth,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LabelSmoothKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
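For reference, both forward paths implement standard label smoothing: with no prior distribution the output is (1 - epsilon) * label + epsilon / num_classes, and with a prior it is (1 - epsilon) * label + epsilon * prior_dist, which is what the functor and LabelSmoothRunDistKernel compute, now with the intermediate arithmetic in MPType. A small NumPy sketch of both paths follows; the function name is illustrative.

    import numpy as np

    def label_smooth_ref(label, epsilon=0.1, prior_dist=None):
        # Compute in float32 (the MPType for float16), store back in the input dtype.
        x = label.astype(np.float32)
        if prior_dist is None:
            out = (1 - epsilon) * x + epsilon / label.shape[-1]    # uniform prior
        else:
            out = (1 - epsilon) * x + epsilon * prior_dist.astype(np.float32)
        return out.astype(label.dtype)

    onehot = np.eye(12, dtype=np.float16)[[3, 7]]   # two one-hot rows over 12 classes
    print(label_smooth_ref(onehot, epsilon=0.1))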
@@ -24,9 +24,10 @@ class TestLabelSmoothOp(OpTest):
    def config(self):
        self.op_type = "label_smooth"
        self.python_api = paddle.nn.functional.label_smooth
+        self.init_dtype()
        self.epsilon = 0.1
        batch_size, self.label_dim = 10, 12
-        self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
+        self.label = np.zeros((batch_size, self.label_dim)).astype(self.dtype)
        nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
        self.label[np.arange(batch_size), nonzero_index] = 1

@@ -39,6 +40,9 @@ class TestLabelSmoothOp(OpTest):
        self.attrs = {'epsilon': self.epsilon}
        self.outputs = {'Out': smoothed_label}

+    def init_dtype(self):
+        self.dtype = np.float64
+
    def test_check_output(self):
        self.check_output(check_eager=True)

@@ -46,6 +50,11 @@ class TestLabelSmoothOp(OpTest):
        self.check_grad(["X"], "Out", check_eager=True)


+class TestLabelSmoothFP16OP(TestLabelSmoothOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
    def setUp(self):
        self.config()
...
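The FP16 test only overrides init_dtype, so it reuses the base class's NumPy reference, check_output and check_grad unchanged. A hypothetical extension along the same lines (not part of this patch) would give the prior-dist test a float16 variant as well, assuming it is placed in the same test module where TestLabelSmoothOpWithPriorDist is defined:

    import numpy as np

    # Hypothetical, not in this patch: reuse the init_dtype hook for the
    # prior-dist test, mirroring TestLabelSmoothFP16OP above.
    class TestLabelSmoothOpWithPriorDistFP16(TestLabelSmoothOpWithPriorDist):
        def init_dtype(self):
            self.dtype = np.float16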
@@ -1923,7 +1923,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
        label(Tensor): The input variable containing the label data. The
                        label data should use one-hot representation. It's
                        a multidimensional tensor with a shape of
-                        :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64".
+                        :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float16", "float32" and "float64".
        prior_dist(Tensor, optional): The prior distribution to be used to smooth
                        labels. If not provided, a uniform distribution
                        is used. It's a multidimensional tensor with a shape of

@@ -1965,7 +1965,7 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
    )
    check_variable_and_dtype(
-        label, 'label', ['float32', 'float64'], 'label_smooth'
+        label, 'label', ['float16', 'float32', 'float64'], 'label_smooth'
    )
    helper = LayerHelper("label_smooth", **locals())
...
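With both kernels and the dtype check updated, label_smooth accepts float16 input. A minimal usage sketch follows; it assumes a CUDA build of Paddle, since the float16 kernels registered here are GPU kernels.

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('gpu')                       # float16 kernel is registered for GPU
    label = paddle.to_tensor([[0, 1, 0],
                              [1, 0, 0]], dtype='float16')
    smoothed = F.label_smooth(label, epsilon=0.1)  # (1 - 0.1) * label + 0.1 / 3
    print(smoothed)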