未验证 提交 910e1b6a 编写于 作者: X xiaohemaikoo 提交者: GitHub

logsumexp support fp16 (#45817)

上级 e86dbd62
...@@ -15,8 +15,16 @@ ...@@ -15,8 +15,16 @@
#include "paddle/phi/kernels/logsumexp_grad_kernel.h" #include "paddle/phi/kernels/logsumexp_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h" #include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
PD_REGISTER_KERNEL( using float16 = phi::dtype::float16;
logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {}
PD_REGISTER_KERNEL(logsumexp_grad,
GPU,
ALL_LAYOUT,
phi::LogsumexpGradKernel,
float,
double,
float16) {}
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
#include "paddle/phi/kernels/logsumexp_kernel.h" #include "paddle/phi/kernels/logsumexp_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" #include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h" #include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
...@@ -23,6 +24,7 @@ ...@@ -23,6 +24,7 @@
namespace phi { namespace phi {
template <typename T>
struct LogsumexpGradFunctor { struct LogsumexpGradFunctor {
template <typename Context, template <typename Context,
typename X, typename X,
...@@ -37,7 +39,13 @@ struct LogsumexpGradFunctor { ...@@ -37,7 +39,13 @@ struct LogsumexpGradFunctor {
DY* dy, DY* dy,
const Dim& dim, const Dim& dim,
int size) { int size) {
dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); using MT = typename phi::dtype::MPTypeTrait<T>::Type;
auto x_mt = (*x).template cast<MT>();
auto y_mt = (*y).template cast<MT>();
auto dy_mt = (*dy).template cast<MT>();
dx->device(place) =
(dy_mt.broadcast(dim) * (x_mt - y_mt.broadcast(dim)).exp())
.template cast<T>();
} }
}; };
...@@ -62,11 +70,11 @@ void LogsumexpGradKernel(const Context& dev_ctx, ...@@ -62,11 +70,11 @@ void LogsumexpGradKernel(const Context& dev_ctx,
auto dx = phi::EigenVector<T>::Flatten(*in_grad); auto dx = phi::EigenVector<T>::Flatten(*in_grad);
auto& place = *dev_ctx.eigen_device(); auto& place = *dev_ctx.eigen_device();
auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}}); auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}});
LogsumexpGradFunctor()( LogsumexpGradFunctor<T>()(
place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]); place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]);
} else { } else {
int rank = in.dims().size(); int rank = in.dims().size();
LogsumexpGradFunctor functor; LogsumexpGradFunctor<T> functor;
std::vector<int32_t> axis32; std::vector<int32_t> axis32;
axis32.reserve(axis.size()); axis32.reserve(axis.size());
std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) { std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) {
...@@ -74,21 +82,26 @@ void LogsumexpGradKernel(const Context& dev_ctx, ...@@ -74,21 +82,26 @@ void LogsumexpGradKernel(const Context& dev_ctx,
}); });
switch (rank) { switch (rank) {
case 1: case 1:
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>( phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32); dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break; break;
case 2: case 2:
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>( phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32); dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break; break;
case 3: case 3:
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>( phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32); dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break; break;
case 4: case 4:
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>( phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32); dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break; break;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported dimensions, please keep maximum dimensions of input "
"data less than 4."));
break;
} }
} }
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/reduce_function.h"
...@@ -23,15 +24,17 @@ ...@@ -23,15 +24,17 @@
namespace phi { namespace phi {
#define HANDLE_DIM(NDIM, RDIM) \ #define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \ if (ndim == NDIM && rdim == RDIM) { \
funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \ funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor<T>>( \
dev_ctx, x, out, axis, keepdim); \ dev_ctx, x, out, axis, keepdim); \
} }
template <typename T>
struct LogsumexpFunctor { struct LogsumexpFunctor {
template <typename Context, typename X, typename Y, typename Dim> template <typename Context, typename X, typename Y, typename Dim>
void operator()(const Context& place, X* x, Y* y, const Dim& dim) { void operator()(const Context& place, X* x, Y* y, const Dim& dim) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
auto x_dim = x->dimensions(); auto x_dim = x->dimensions();
auto t_dim = x_dim; auto t_dim = x_dim;
for (int i = 0; i < static_cast<int>(dim.size()); i++) { for (int i = 0; i < static_cast<int>(dim.size()); i++) {
...@@ -46,12 +49,14 @@ struct LogsumexpFunctor { ...@@ -46,12 +49,14 @@ struct LogsumexpFunctor {
r_dim[dim[i]] = x_dim[dim[i]]; r_dim[dim[i]] = x_dim[dim[i]];
} }
auto x_mt = (*x).template cast<MT>();
auto y_dim = y->dimensions(); auto y_dim = y->dimensions();
auto x_max = x->maximum(dim); auto x_max = x_mt.maximum(dim);
y->device(place) = y->device(place) =
(x_max + (x_max +
(*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) (x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
.reshape(y_dim); .reshape(y_dim)
.template cast<T>();
} }
}; };
...@@ -74,10 +79,16 @@ void LogsumexpKernel(const Context& dev_ctx, ...@@ -74,10 +79,16 @@ void LogsumexpKernel(const Context& dev_ctx,
auto output = phi::EigenScalar<T>::From(*out); auto output = phi::EigenScalar<T>::From(*out);
auto& place = *dev_ctx.eigen_device(); auto& place = *dev_ctx.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}}); auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor()(place, &input, &output, reduce_dim); LogsumexpFunctor<T>()(place, &input, &output, reduce_dim);
} else { } else {
int ndim = input_dim_size; int ndim = input_dim_size;
int rdim = axis.size(); int rdim = axis.size();
if (ndim > 4) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported dimensions, please keep maximum dimensions of input "
"data less than 4."));
}
// comments for accelerating compiling temporarily. // comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5); // HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4); // HANDLE_DIM(6, 4);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle import paddle
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
...@@ -35,6 +36,22 @@ def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False): ...@@ -35,6 +36,22 @@ def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False):
return paddle.logsumexp(x, axis, keepdim) return paddle.logsumexp(x, axis, keepdim)
def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False):
paddle.disable_static()
tensor_x = paddle.to_tensor(x)
tensor_x.stop_gradient = False
out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all)
grad = paddle.grad(out, [tensor_x])
x_grad = grad[0].numpy()
paddle.enable_static()
return x_grad
def logsumexp_ref_grad(x):
sum = np.exp(x).sum()
return np.exp(x) / sum
class TestLogsumexp(OpTest): class TestLogsumexp(OpTest):
def setUp(self): def setUp(self):
...@@ -125,6 +142,47 @@ class TestLogsumexp_reduce_all(TestLogsumexp): ...@@ -125,6 +142,47 @@ class TestLogsumexp_reduce_all(TestLogsumexp):
self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)] self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)]
class TestLogsumexp_FP32(TestLogsumexp):
def set_attrs(self):
self.dtype = 'float32'
def test_check_grad(self):
self.__class__.dtype = self.dtype
x_grad = logsumexp_op_grad(self.inputs['X'])
ref_x_grad = logsumexp_ref_grad(self.inputs['X'])
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-08, atol=1e-08)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestLogsumexp_FP16(TestLogsumexp):
def set_attrs(self):
self.dtype = 'float16'
def test_check_output(self):
ref_x = self.inputs['X'].astype(np.float32)
out_ref = ref_logsumexp(ref_x)
paddle.disable_static()
x = self.inputs['X'].astype(np.float16)
tensor_x = paddle.to_tensor(x)
out_pad = logsumexp_wrapper(tensor_x)
paddle.enable_static()
np.testing.assert_allclose(out_pad.numpy(),
out_ref,
rtol=1e-03,
atol=1e-08)
def test_check_grad(self):
self.__class__.dtype = self.dtype
ref_x = self.inputs['X'].astype(np.float32)
ref_x_grad = logsumexp_ref_grad(ref_x)
x = self.inputs['X'].astype(np.float16)
x_grad = logsumexp_op_grad(x)
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05)
class TestLogsumexpError(unittest.TestCase): class TestLogsumexpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册