Unverified commit 468869e4 authored by cyberslack_lee, committed by GitHub

【Hackathon4 No58】logcumsum logsum (#51275)

Parent 43efb979
......@@ -461,5 +461,6 @@ PD_REGISTER_KERNEL(logcumsumexp,
phi::LogcumsumexpKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
......@@ -34,5 +34,6 @@ PD_REGISTER_KERNEL(logcumsumexp_grad,
phi::LogcumsumexpGradKernel,
phi::dtype::float16,
float,
double) {}
double,
phi::dtype::bfloat16) {}
#endif
......@@ -19,12 +19,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(logsumexp_grad,
GPU,
ALL_LAYOUT,
phi::LogsumexpGradKernel,
float,
double,
float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/logsumexp_kernel.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
......@@ -22,8 +23,6 @@
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
using float16 = phi::dtype::float16;
namespace phi {
template <typename T>
......@@ -39,6 +38,14 @@ struct LogCUDAFunctor<float16> {
}
};
template <>
struct LogCUDAFunctor<bfloat16> {
HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
auto x_ = static_cast<float>(x);
return static_cast<bfloat16>(std::log(x_));
}
};
template <typename T, typename Context>
void LogsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -112,5 +119,11 @@ void LogsumexpKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
PD_REGISTER_KERNEL(logsumexp,
GPU,
ALL_LAYOUT,
phi::LogsumexpKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
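Both the float16 and bfloat16 specializations of LogCUDAFunctor follow the same compute-in-float32 pattern: widen the narrow input, take the log, and cast the result back. A minimal NumPy sketch of that pattern for bfloat16 payloads stored as uint16 (simple truncation is used for the narrowing step; this is an illustration, not the phi::dtype::bfloat16 conversion itself):

import numpy as np

def bf16_log(bits: np.ndarray) -> np.ndarray:
    # Widen: the uint16 payload is the upper half of a float32 bit pattern.
    x_f32 = (bits.astype(np.uint32) << 16).astype(np.uint32).view(np.float32)
    y_f32 = np.log(x_f32)  # compute in float32, as the CUDA functor does
    # Narrow back to a bfloat16 payload (plain truncation for brevity).
    return (y_f32.view(np.uint32) >> 16).astype(np.uint16)

x_bits = (np.array([1.0, 2.0, 8.0], dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)
print(bf16_log(x_bits))  # bfloat16 payloads of log(1), log(2), log(8)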
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
......
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
......
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
......
......@@ -17,7 +17,11 @@ import unittest
from typing import Optional
import numpy as np
from eager_op_test import OpTest
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)
import paddle
from paddle import fluid
......@@ -314,5 +318,48 @@ class TestLogcumsumexpFP16(unittest.TestCase):
np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=2e-03)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestLogcumsumexpBF16Op(OpTest):
def setUp(self):
self.op_type = 'logcumsumexp'
self.dtype = np.uint16
self.python_api = logcumsumexp_wrapper
x = np.arange(100, dtype=np.float64).reshape(10, 10)
output = np_logcumsumexp(x)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(output)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place_customized(
checker=self.verify_output, place=place
)
def verify_output(self, outs):
outs = convert_uint16_to_float(outs)
self.assertEqual(outs[0].shape, (10, 10))
hist, _ = np.histogram(outs[0], range=(-3, 5))
hist = hist.astype("float64")
hist /= float(outs[0].size)
x = np.arange(100, dtype=np.float64).reshape(10, 10)
data = np_logcumsumexp(x)
hist2, _ = np.histogram(data, range=(-3, 5))
hist2 = hist2.astype("float64")
hist2 /= float(outs[0].size)
np.testing.assert_allclose(hist, hist2, rtol=0.3)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', numeric_grad_delta=0.5, max_relative_error=0.5
)
if __name__ == '__main__':
unittest.main()
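The bfloat16 test keeps its data as np.uint16 via convert_float_to_uint16 and checks histograms with rtol=0.3 rather than elementwise values, because bfloat16 carries only about 8 significand bits (2-3 decimal digits) while the inputs span 0..99. A rough sketch of the round trip those helpers approximate (plain truncation shown here; the real helpers in eager_op_test may round), assuming the uint16 value is the upper half of the float32 bit pattern:

import numpy as np

def to_bf16_bits(x):
    # Keep the top 16 bits of the float32 representation (truncation).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(bits):
    # Re-expand the payload into a float32 value.
    return (bits.astype(np.uint32) << 16).astype(np.uint32).view(np.float32)

x = np.arange(100, dtype=np.float64).reshape(10, 10)
roundtrip = from_bf16_bits(to_bf16_bits(x))
# One round trip costs well under one percent of relative error, which is why
# the test buckets values into a histogram instead of comparing each element.
print(np.max(np.abs(roundtrip - x)))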
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -184,6 +184,47 @@ class TestLogsumexp_FP16(TestLogsumexp):
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestLogsumexpBF16Op(TestLogsumexp):
def setUp(self):
self.op_type = 'logsumexp'
self.python_api = logsumexp_wrapper
self.dtype = np.uint16
self.shape = [2, 3, 4, 5]
self.axis = [-1]
self.keepdim = False
self.reduce_all = False
self.set_attrs()
x = np.random.uniform(-1, 1, self.shape).astype(np.float64)
out = ref_logsumexp(x, self.axis, self.keepdim, self.reduce_all)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}
self.attrs = {
'axis': self.axis,
'keepdim': self.keepdim,
'reduce_all': self.reduce_all,
}
self.set_attrs_addition()
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
def set_attrs(self):
pass
def set_attrs_addition(self):
pass
class TestLogsumexpError(unittest.TestCase):
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
......
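The reference function ref_logsumexp used in setUp is defined earlier in test_logsumexp_op.py and elided from this diff. A numerically stable reference of this kind is typically written along the following lines (an illustrative sketch, not the file's exact implementation):

import numpy as np

def ref_logsumexp_sketch(x, axis=None, keepdim=False, reduce_all=False):
    if reduce_all:
        axis = None
    elif isinstance(axis, list):
        axis = tuple(axis)
    # Subtract the per-slice maximum before exponentiating to avoid overflow.
    x_max = np.max(x, axis=axis, keepdims=True)
    out = x_max + np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))
    return out if keepdim else np.squeeze(out, axis=axis)

x = np.random.uniform(-1, 1, [2, 3, 4, 5])
print(ref_logsumexp_sketch(x, axis=[-1]).shape)  # (2, 3, 4)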
......@@ -2203,7 +2203,7 @@ def logsumexp(x, axis=None, keepdim=False, name=None):
return _C_ops.logsumexp(x, axis, keepdim, reduce_all)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'logsumexp'
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'logsumexp'
)
helper = LayerHelper('logsumexp', **locals())
......@@ -3351,7 +3351,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
return _C_ops.logcumsumexp(x, axis, flatten, False, False)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], "logcumsumexp"
x, 'x', ['float16', 'float32', 'float64', 'uint16'], "logcumsumexp"
)
helper = LayerHelper('logcumsumexp', **locals())
......
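'uint16' appears in the allowed dtype lists because Paddle's static graph and numpy interface represent bfloat16 tensors with a uint16 storage type. With the GPU kernels above registered, both ops accept bfloat16 input; a hedged usage sketch (assuming a CUDA build on a device with bfloat16 support, and that Tensor.astype accepts 'bfloat16'):

import paddle

paddle.set_device('gpu')
x = paddle.rand([2, 3, 4], dtype='float32').astype('bfloat16')

y1 = paddle.logsumexp(x, axis=-1)    # reduce over the last axis
y2 = paddle.logcumsumexp(x, axis=1)  # cumulative reduction along axis 1
# Both results are expected to stay in bfloat16.
print(y1.dtype, y2.dtype, y1.shape, y2.shape)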