Unverified commit 468869e4, authored by cyberslack_lee, committed by GitHub

【Hackathon4 No58】logcumsum logsum (#51275)

Parent 43efb979
@@ -461,5 +461,6 @@ PD_REGISTER_KERNEL(logcumsumexp,
                    phi::LogcumsumexpKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif
@@ -34,5 +34,6 @@ PD_REGISTER_KERNEL(logcumsumexp_grad,
                    phi::LogcumsumexpGradKernel,
                    phi::dtype::float16,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::bfloat16) {}
 #endif
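
With the forward and backward logcumsumexp kernels above now registered for phi::dtype::bfloat16, the op can run end to end on bfloat16 GPU tensors. A minimal usage sketch, assuming a Paddle build that contains this commit and a CUDA device with bfloat16 support (e.g. Ampere or newer):

    import paddle

    # Assumes a GPU build of Paddle that includes this commit.
    paddle.set_device('gpu')

    x = paddle.arange(12, dtype='float32').reshape([3, 4]).astype('bfloat16')
    x.stop_gradient = False

    # Forward pass uses the bfloat16 logcumsumexp kernel registered above.
    y = paddle.logcumsumexp(x, axis=1)

    # Backward pass exercises the bfloat16 logcumsumexp_grad kernel.
    y.sum().backward()

    print(y.astype('float32'))
    print(x.grad.astype('float32'))
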
@@ -19,12 +19,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
 
-using float16 = phi::dtype::float16;
-
 PD_REGISTER_KERNEL(logsumexp_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::LogsumexpGradKernel,
                    float,
                    double,
-                   float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/logsumexp_kernel.h"
 
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/elementwise_add_kernel.h"
@@ -22,8 +23,6 @@
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
 
-using float16 = phi::dtype::float16;
-
 namespace phi {
 
 template <typename T>
@@ -39,6 +38,14 @@ struct LogCUDAFunctor<float16> {
   }
 };
 
+template <>
+struct LogCUDAFunctor<bfloat16> {
+  HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
+    auto x_ = static_cast<float>(x);
+    return static_cast<bfloat16>(std::log(x_));
+  }
+};
+
 template <typename T, typename Context>
 void LogsumexpKernel(const Context& dev_ctx,
                      const DenseTensor& x,
@@ -112,5 +119,11 @@ void LogsumexpKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
+PD_REGISTER_KERNEL(logsumexp,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogsumexpKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
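
The LogCUDAFunctor<bfloat16> specialization above widens each value to float32, takes the log, and narrows back, since bfloat16 carries only about 8 mantissa bits. A small end-to-end sketch of the newly registered logsumexp path, assuming a Paddle build with this commit and a bfloat16-capable GPU; the loose tolerance reflects that precision:

    import numpy as np
    import paddle

    paddle.set_device('gpu')

    x = paddle.uniform([2, 3, 4, 5], min=-1.0, max=1.0)

    # bfloat16 forward pass through the kernel registered above.
    y_bf16 = paddle.logsumexp(x.astype('bfloat16'), axis=-1)
    y_fp32 = paddle.logsumexp(x, axis=-1)

    # bfloat16 has roughly 3 decimal digits of precision, so compare loosely.
    np.testing.assert_allclose(
        y_bf16.astype('float32').numpy(), y_fp32.numpy(), rtol=2e-2, atol=2e-2
    )
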
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 
 namespace phi {
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 
 namespace phi {
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 
 namespace phi {
@@ -17,7 +17,11 @@ import unittest
 from typing import Optional
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 
 import paddle
 from paddle import fluid
@@ -314,5 +318,48 @@ class TestLogcumsumexpFP16(unittest.TestCase):
         np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=2e-03)
 
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestLogcumsumexpBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'logcumsumexp'
+        self.dtype = np.uint16
+        self.python_api = logcumsumexp_wrapper
+        x = np.arange(100, dtype=np.float64).reshape(10, 10)
+        output = np_logcumsumexp(x)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(output)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place_customized(
+            checker=self.verify_output, place=place
+        )
+
+    def verify_output(self, outs):
+        outs = convert_uint16_to_float(outs)
+        self.assertEqual(outs[0].shape, (10, 10))
+        hist, _ = np.histogram(outs[0], range=(-3, 5))
+        hist = hist.astype("float64")
+        hist /= float(outs[0].size)
+        x = np.arange(100, dtype=np.float64).reshape(10, 10)
+        data = np_logcumsumexp(x)
+        hist2, _ = np.histogram(data, range=(-3, 5))
+        hist2 = hist2.astype("float64")
+        hist2 /= float(outs[0].size)
+        np.testing.assert_allclose(hist, hist2, rtol=0.3)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', numeric_grad_delta=0.5, max_relative_error=0.5
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
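
The test above compares histograms of the bfloat16 output against np_logcumsumexp, a reference helper defined earlier in the test file and not shown in this diff. For orientation, a numerically stable log-cumsum-exp can be written along these lines (a sketch only, not the exact helper used by the test):

    import numpy as np

    def np_logcumsumexp_sketch(x: np.ndarray, axis: int = -1) -> np.ndarray:
        # Shift by the max along `axis` so exp() cannot overflow, cumulatively
        # sum the exponentials, then shift back in log space.
        x = np.asarray(x, dtype=np.float64)
        x_max = np.max(x, axis=axis, keepdims=True)
        return np.log(np.cumsum(np.exp(x - x_max), axis=axis)) + x_max

    # For a monotone ramp, each cumulative entry is dominated by its last term.
    x = np.arange(100, dtype=np.float64).reshape(10, 10)
    print(np_logcumsumexp_sketch(x, axis=-1)[0])
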
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -184,6 +184,47 @@ class TestLogsumexp_FP16(TestLogsumexp):
         np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05)
 
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestLogsumexpBF16Op(TestLogsumexp):
+    def setUp(self):
+        self.op_type = 'logsumexp'
+        self.python_api = logsumexp_wrapper
+        self.dtype = np.uint16
+        self.shape = [2, 3, 4, 5]
+        self.axis = [-1]
+        self.keepdim = False
+        self.reduce_all = False
+        self.set_attrs()
+        x = np.random.uniform(-1, 1, self.shape).astype(np.float64)
+        out = ref_logsumexp(x, self.axis, self.keepdim, self.reduce_all)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.attrs = {
+            'axis': self.axis,
+            'keepdim': self.keepdim,
+            'reduce_all': self.reduce_all,
+        }
+        self.set_attrs_addition()
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+    def set_attrs(self):
+        pass
+
+    def set_attrs_addition(self):
+        pass
+
 class TestLogsumexpError(unittest.TestCase):
     def test_errors(self):
         with paddle.static.program_guard(paddle.static.Program()):
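
Both BF16 tests feed uint16 arrays because NumPy has no native bfloat16 type: a bfloat16 value is simply the upper 16 bits of the corresponding float32 bit pattern. A self-contained sketch of that round trip, illustrative only; the tests use the convert_float_to_uint16 / convert_uint16_to_float helpers from eager_op_test rather than this code:

    import numpy as np

    def float32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
        # Truncate float32 to bfloat16 by keeping the high 16 bits of its bits.
        return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
        # Widen back: place the stored 16 bits in the high half, zero the rest.
        return (bits.astype(np.uint32) << 16).view(np.float32)

    x = np.random.uniform(-1, 1, [2, 3]).astype(np.float32)
    x_round_trip = bf16_bits_to_float32(float32_to_bf16_bits(x))

    # bfloat16 keeps ~8 mantissa bits, so the round trip is only coarsely exact.
    np.testing.assert_allclose(x_round_trip, x, rtol=1e-2, atol=1e-2)
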
@@ -2203,7 +2203,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
         return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'logsumexp'
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'logsumexp'
         )
         helper = LayerHelper('logsumexp', **locals())
@@ -3351,7 +3351,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
         return _C_ops.logcumsumexp(x, axis, flatten, False, False)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], "logcumsumexp"
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], "logcumsumexp"
         )
         helper = LayerHelper('logcumsumexp', **locals())
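
In these static-graph dtype checks, 'uint16' is the label Paddle's dtype conversion uses for bfloat16 variables, so adding it lets logsumexp and logcumsumexp accept bfloat16 inputs under paddle.static as well. A minimal sketch of that path (assuming the kernels registered above and a Paddle version whose paddle.static.data accepts 'bfloat16'):

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # A 'bfloat16' static variable is reported as 'uint16' by the
        # check_variable_and_dtype lists extended above.
        x = paddle.static.data(name='x', shape=[4, 8], dtype='bfloat16')
        y = paddle.logsumexp(x, axis=-1)
        z = paddle.logcumsumexp(x, axis=1)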