Unverified commit b666fd3c authored by Guoxia Wang, committed by GitHub

support l2_normalize float16 (#35776)

* support fp16 dtype
Parent: 7babb3d2
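For context, here is a minimal usage sketch of what this change enables: calling `fluid.layers.l2_normalize` on a float16 tensor. It is illustrative only (not part of the commit) and assumes a CUDA build of Paddle, since the fp16 kernels below are registered only for the CUDA device.

```python
# Illustrative only, not part of this commit: exercising l2_normalize with a
# float16 input. Assumes a CUDA build of Paddle (the fp16 norm kernel added
# below is registered only for the CUDA device).
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.set_device("gpu")  # the fp16 kernel lives in norm_op.cu
x = paddle.to_tensor(np.random.rand(2, 3).astype("float16"))
y = fluid.layers.l2_normalize(x, axis=-1)  # accepts float16 after this change
print(y.dtype)  # float16
# Each row of y should have (approximately) unit L2 norm:
print(np.linalg.norm(y.numpy().astype("float32"), axis=-1))
```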
@@ -20,11 +20,17 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/norm_op.h"
#include "paddle/fluid/platform/bfloat16.h"
namespace paddle {
namespace operators {
+__device__ __forceinline__ platform::float16 square_root(platform::float16 x) {
+  return static_cast<platform::float16>(sqrtf(static_cast<float>(x)));
+}
__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
@@ -33,28 +39,29 @@ template <typename T, int BlockDim>
__global__ void Normalize(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, const T eps, T* y, T* out_norm) {
-typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+using MT = typename details::MPTypeTrait<T>::Type;
+typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
-T sum = 0.0;
-__shared__ T norm;
+MT sum = 0.0;
+__shared__ MT norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-const T x_ij = x[base + j * post];
+const MT x_ij = static_cast<MT>(x[base + j * post]);
sum += x_ij * x_ij;
}
-T reduce_result = BlockReduce(temp_storage).Sum(sum);
+MT reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) {
-norm = square_root(reduce_result + eps);
-out_norm[i] = norm;
+norm = square_root(reduce_result + static_cast<MT>(eps));
+out_norm[i] = static_cast<T>(norm);
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const int index = base + j * post;
-y[index] = x[index] / norm;
+y[index] = static_cast<T>((static_cast<MT>(x[index]) / norm));
}
}
}
@@ -109,34 +116,36 @@ template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
const int pre, const int axis_n,
const int post, T* x_grad) {
-typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+using MT = typename details::MPTypeTrait<T>::Type;
+typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
-T sum = 0.0;
-__shared__ T row_sum;
-__shared__ T row_sqrt_norm;
-__shared__ T row_norm;
+MT sum = 0.0;
+__shared__ MT row_sum;
+__shared__ MT row_sqrt_norm;
+__shared__ MT row_norm;
auto base = (i / post) * post * axis_n + (i % post);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
-sum += x[index] * y_grad[index];
+sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
}
-T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
+MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
if (threadIdx.x == 0) {
row_sum = reduce_result;
-row_sqrt_norm = x_norm[i];
+row_sqrt_norm = static_cast<MT>(x_norm[i]);
row_norm = row_sqrt_norm * row_sqrt_norm;
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
-const T x_ij = x[index];
-const T dy_ij = y_grad[index];
-x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
+const MT x_ij = static_cast<MT>(x[index]);
+const MT dy_ij = static_cast<MT>(y_grad[index]);
+x_grad[index] =
+static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
}
}
}
@@ -181,7 +190,11 @@ class NormGradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
-REGISTER_OP_CUDA_KERNEL(norm, ops::NormCUDAKernel<CUDA, float>,
+REGISTER_OP_CUDA_KERNEL(norm,
+ops::NormCUDAKernel<CUDA, paddle::platform::float16>,
+ops::NormCUDAKernel<CUDA, float>,
ops::NormCUDAKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradCUDAKernel<CUDA, float>,
-ops::NormGradCUDAKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(
+norm_grad, ops::NormGradCUDAKernel<CUDA, paddle::platform::float16>,
+ops::NormGradCUDAKernel<CUDA, float>,
+ops::NormGradCUDAKernel<CUDA, double>);
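The kernel changes above follow the usual mixed-precision pattern: `MPTypeTrait<T>::Type` maps `float16` to `float`, so the squared-sum reduction, the square root, and the division all run in fp32, and only the inputs and outputs stay in fp16. Below is a NumPy sketch of that scheme (illustrative only; the helper name is made up and is not a Paddle API).

```python
# NumPy sketch (not Paddle API; helper name is made up) of the mixed-precision
# scheme the kernels above use for float16: read fp16, accumulate the squared
# sum and divide in fp32 (MPTypeTrait's "MT"), cast the results back to fp16.
import numpy as np

def l2_normalize_fp16_reference(x_fp16, axis, eps=1e-12):
    x32 = x_fp16.astype(np.float32)  # MT accumulation type
    norm32 = np.sqrt((x32 * x32).sum(axis=axis, keepdims=True) + eps)
    y = (x32 / norm32).astype(np.float16)  # Out, cast back to T
    return y, norm32.astype(np.float16)    # Norm is also stored as T

x = np.random.rand(2, 3).astype(np.float16)
y, norm = l2_normalize_fp16_reference(x, axis=1)
print(y.dtype, norm.dtype)  # float16 float16
```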
@@ -5041,64 +5041,42 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
slice along dimension `axis`.
Args:
-x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float32 or float64.
+x(Variable|list): The input tensor could be N-D tensor, and the input data type could be float16, float32 or float64.
axis(int): The axis on which to apply normalization. If `axis < 0`, \
the dimension to normalization is rank(X) + axis. -1 is the
last dimension.
epsilon(float): The epsilon value is used to avoid division by zero, \
the default value is 1e-12.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: The output has the same shape and data type with `x`.
Examples:
-.. code-block:: python
-# declarative mode
-import paddle.fluid as fluid
-import numpy as np
+.. code-block:: python
+:name: code-example1
import paddle
-paddle.enable_static()
-input = fluid.data(name="input", shape=[2,3])
-output = fluid.layers.l2_normalize(x=input,axis=0)
-place = fluid.CPUPlace()
-exe = fluid.Executor(place)
-exe.run(fluid.default_startup_program())
-input_data = np.random.rand(2,3).astype("float32")
-print(input_data)
-# [[0.5171216 0.12704141 0.56018186]
-# [0.93251234 0.5382788 0.81709313]]
-output_data = exe.run(fluid.default_main_program(),
-feed={"input":input_data},
-fetch_list=[output],
-return_numpy=True)
-print(output_data)
-# [array([[0.48496857, 0.22970329, 0.56545246],
-# [0.8745316 , 0.9732607 , 0.82478094]], dtype=float32)]
-# imperative mode
-import paddle.fluid.dygraph as dg
-with dg.guard(place) as g:
-input = dg.to_variable(input_data)
-output = fluid.layers.l2_normalize(x=input, axis=-1)
-print(output.numpy())
+X = paddle.randn(shape=[3, 5], dtype='float64')
+out = paddle.fluid.layers.l2_normalize(X, axis=-1)
+print(out.numpy())
-# [[0.66907585 0.16437206 0.7247892 ]
-# [0.6899054 0.3982376 0.6045142 ]]
+# [[ 0.21558504 0.56360189 0.47466096 0.46269539 -0.44326736]
+# [-0.70602414 -0.52745777 0.37771788 -0.2804768 -0.04449922]
+# [-0.33972208 -0.43014923 0.31772556 0.76617881 -0.10761525]]
"""
if len(x.shape) == 1:
axis = 0
check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
if in_dygraph_mode():
_, out = _C_ops.norm(x, 'axis', 1
if axis is None else axis, 'epsilon', epsilon)
return out
check_variable_and_dtype(x, "X", ("float16", "float32", "float64"), "norm")
helper = LayerHelper("l2_normalize", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
@@ -33,23 +33,27 @@ class TestNormOp(OpTest):
def setUp(self):
self.op_type = "norm"
self.init_test_case()
-x = np.random.random(self.shape).astype("float64")
+self.init_dtype()
+x = np.random.random(self.shape).astype(self.dtype)
y, norm = l2_norm(x, self.axis, self.epsilon)
self.inputs = {'X': x}
self.attrs = {'epsilon': self.epsilon, 'axis': self.axis}
self.outputs = {'Out': y, 'Norm': norm}
def test_check_output(self):
-self.check_output()
+self.check_output(atol=1e-5)
def test_check_grad(self):
-self.check_grad(['X'], 'Out')
+self.check_grad(['X'], 'Out', max_relative_error=0.008)
def init_test_case(self):
self.shape = [2, 3, 4, 5]
self.axis = 1
self.epsilon = 1e-8
+def init_dtype(self):
+self.dtype = "float64"
class TestNormOp2(TestNormOp):
def init_test_case(self):
@@ -89,6 +93,25 @@ class TestNormOp5(TestNormOp):
pass
+class TestNormOp6(TestNormOp):
+def init_dtype(self):
+self.dtype = "float32"
+@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
+"core is not compiled with CUDA")
+class TestNormOp7(TestNormOp):
+def init_dtype(self):
+self.dtype = "float16"
+def test_check_output(self):
+self.check_output_with_place(fluid.core.CUDAPlace(0), atol=5e-2)
+def test_check_grad(self):
+self.check_grad_with_place(
+fluid.core.CUDAPlace(0), ['X'], 'Out', max_relative_error=0.05)
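TestNormOp7 relaxes the output tolerance to `atol=5e-2` and the gradient check to `max_relative_error=0.05`. A quick look at float16 resolution (illustrative, not from the commit) gives a sense of why bit-exact agreement with the float64 reference is not expected even though the kernel accumulates in fp32.

```python
# Illustrative check of float16 resolution, motivating the looser tolerances
# used by TestNormOp7 compared with the float32/float64 tests.
import numpy as np

print(np.finfo(np.float16).eps)         # ~9.77e-04 unit roundoff
print(np.finfo(np.float16).resolution)  # 1e-03, roughly 3 decimal digits
```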
@skip_check_grad_ci(reason="skip check grad for test mode.")
class TestNormTestOp(OpTest):
def setUp(self):
......