Unverified commit ce4637c1 authored by 201716010711, committed by GitHub

Support fp16 squared_l2_norm (#48315)

Parent 2e156ac8
@@ -15,12 +15,46 @@
 #include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+
+namespace phi {
+
+/**
+ * x*y*2.0
+ */
+template <typename T>
+struct DoubleMulFunctor {
+  __device__ __forceinline__ T operator()(const T a, const T b) const {
+    return b * a * static_cast<T>(2.0f);
+  }
+};
+
+template <typename T, typename Context>
+void SquaredL2NormGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& dout,
+                             DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+  PADDLE_ENFORCE_EQ(
+      dout.numel(),
+      1,
+      phi::errors::InvalidArgument(
+          "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar."));
+  std::vector<const DenseTensor*> ins{&x, &dout};
+  std::vector<DenseTensor*> outs{dx};
+  funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
+      dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
+}
+
+}  // namespace phi
+
 PD_REGISTER_KERNEL(squared_l2_norm_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::SquaredL2NormGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
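
For orientation, the new CUDA backward kernel broadcasts the scalar upstream gradient dout against x and applies DoubleMulFunctor, i.e. dx = 2 * x * dout. A minimal NumPy sketch of that rule (illustration only, not part of the commit; shapes are hypothetical):

    import numpy as np

    # Gradient rule implemented by DoubleMulFunctor above:
    # y = sum(x ** 2)  =>  dy/dx = 2 * x, so with a scalar upstream
    # gradient `dout`, dx = 2 * x * dout (computed element-wise).
    x = np.random.uniform(-0.1, 0.1, (8, 5, 10)).astype(np.float32)
    dout = np.float32(1.0)
    dx = 2.0 * x * dout
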
@@ -15,9 +15,30 @@
 #include "paddle/phi/kernels/squared_l2_norm_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h"
-
-PD_REGISTER_KERNEL(
-    squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) {
-}
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SquaredL2NormKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  std::vector<int> origin_reduce_dims;
+  for (size_t i = 0; i < x.dims().size(); i++) {
+    origin_reduce_dims.push_back(i);
+  }
+  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
+      dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(squared_l2_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SquaredL2NormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
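
With the float16 registrations above, the op and its gradient can run natively on fp16 GPU tensors. A hedged usage sketch (assumes a CUDA build of Paddle; it mirrors the eager-mode call used in the tests below):

    import paddle
    from paddle import _legacy_C_ops

    # squared_l2_norm on a float16 tensor now dispatches to the fp16 GPU
    # kernel registered above instead of requiring a float32 fallback.
    x = paddle.randn([8, 5, 10]).astype('float16')
    x.stop_gradient = False
    y = _legacy_C_ops.squared_l2_norm(x)  # scalar: sum(x ** 2)
    (dx,) = paddle.grad(y, [x])           # element-wise: 2 * x * dy
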
@@ -254,10 +254,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         self.assertListEqual(
             ops,
             [
-                'square',
-                'reduce_sum',
-                'square',
-                'reduce_sum',
+                'squared_l2_norm',
+                'squared_l2_norm',
                 'sum',
                 'cast',
                 'sqrt',
......
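
The gradient-clipping test above now expects the fused squared_l2_norm op once per gradient, in place of separate square + reduce_sum pairs. For context, a hedged sketch of the public API whose generated program the test inspects (the Linear layer is a throwaway stand-in, not the test's actual network):

    import paddle

    # ClipGradByGlobalNorm builds: squared_l2_norm per gradient, then
    # sum, cast, sqrt -- the op sequence asserted in the test above.
    model = paddle.nn.Linear(10, 10)
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    opt = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters(), grad_clip=clip
    )
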
@@ -30,6 +30,46 @@ def test_squared_l2_norm(x):
     return _legacy_C_ops.squared_l2_norm(x)
+
+
+class TestSquaredL2NormF16Op(unittest.TestCase):
+    def init_test_case(self):
+        X = np.random.uniform(-0.1, 0.1, (8, 5, 10)).astype('float32')
+        return X
+
+    def check_main(self, x_np, dtype):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np)
+        x = x.astype(dtype)
+        x.stop_gradient = False
+        y = test_squared_l2_norm(x)
+        x_g = paddle.grad(y, [x])
+        paddle.enable_static()
+        return y, x_g
+
+    def test_main(self):
+        x_np = self.init_test_case()
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float32')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float16')
+
+        def assert_equal(x, y):
+            np.testing.assert_allclose(x, y, rtol=1e-05, atol=0.0)
+
+        assert_equal(y_np_1, y_np_2)
+        assert_equal(x_g_np_1, x_g_np_2)
+
+
+class TestSquaredL2NormF16Op1(TestSquaredL2NormF16Op):
+    def init_test_case(self):
+        X = np.random.uniform(-2.0, 2.0, (30, 10)).astype('float32')
+        return X
+
+
+class TestSquaredL2NormF16Op2(TestSquaredL2NormF16Op):
+    def init_test_case(self):
+        X = np.random.uniform(-5.0, 5.0, (20, 10, 20)).astype('float32')
+        return X
+
+
 class TestL2LossOp(OpTest):
     """Test squared_l2_norm"""
......
@@ -207,11 +207,8 @@ def _squared_l2_norm(x):
     """
     x = _cast_to_mp_type_if_enabled(x)
-    if (
-        core.is_compiled_with_xpu()
-        or x.dtype == core.VarDesc.VarType.FP16
-        or x.dtype == core.VarDesc.VarType.BF16
-    ):
+    if core.is_compiled_with_xpu():
         square = paddle.square(x)
         sum_square = paddle.sum(square)
         return sum_square
@@ -220,7 +217,7 @@ def _squared_l2_norm(x):
         return _C_ops.squared_l2_norm(x)
     op_type = 'squared_l2_norm'
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
+    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'float16'], op_type)
     helper = LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
......
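
After the clip.py change, only XPU builds keep the square + sum fallback; fp16 tensors go through the fused kernel, and the static-graph dtype check now accepts float16. A hedged equivalence sketch (assumes a CUDA build; the two values should match up to fp16 rounding):

    import paddle
    from paddle import _legacy_C_ops

    # The fused op and the fallback compute the same quantity, sum(x ** 2);
    # the fused path is what fp16 global-norm clipping now uses.
    x = paddle.randn([4, 4]).astype('float16')
    fused = _legacy_C_ops.squared_l2_norm(x)
    fallback = paddle.sum(paddle.square(x))
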